The Gory Details of Finetuning SDXL for 40M samples

Details on how the big SDXL finetunes are trained are scarce, so just like with version 1 of my model bigASP, I'm sharing all the details here to help the community. This is going to be long, because I'm dumping as much about my experience as I can. I hope it helps someone out there.

Overview

Version 2 was trained on 6,716,761 images, all with resolutions exceeding 1MP, and sourced as originals whenever possible, to reduce compression artifacts to a minimum. Each image is about 1MB on disk, making the dataset about 1TB per million images.

Prior to training, every image goes through the following pipeline:

* CLIP-B/32 embeddings, which get saved to the database and used for later stages of the pipeline. This is also the stage where images that cannot be loaded are filtered out.

* A custom trained quality model rates each image from 0 to 9, inclusive.

* JoyTag is used to generate tags for each image.

* JoyCaption Alpha Two is used to generate captions for each image.

* OWLv2 with the prompt "a watermark" is used to detect watermarks in the images.

* VAE encoding, saving the pre-encoded latents with gzip compression to disk.

Training was done using a custom training script, which uses the diffusers library to handle the model itself. This has pros and cons versus using a more established training script like kohya. It allows me to fully understand all the inner mechanics and implement any tweaks I want. The downside is that a lot of time has to be spent debugging subtle issues that crop up, which often results in expensive mistakes. For me, those mistakes are just the cost of learning and the trade off is worth it. But I by no means recommend this form of masochism.

The Quality Model

Scoring all images in the dataset from 0 to 9 allows two things. First, all images scored at 0 are completely dropped from training. In my case, I specifically have to filter out things like ads, video preview thumbnails, etc from my dataset, which I ensure get sorted into the 0 bin. Second, during training score tags are prepended to the image prompts. Later, users can use these score tags to guide the quality of their generations. This, theoretically, allows the model to still learn from "bad images" in its training set, while retaining high quality outputs during inference. This particular method of using score tags was pioneered by the incredible Pony Diffusion models.

The model that judges the quality of images is built in two phases. First, I manually collect a dataset of head-to-head image comparisons. This is a dataset where each entry is two images, and a value indicating which image is "better" than the other. I built this dataset by rating 2000 images myself. An image is considered better as agnostically as possible. For example, a color photo isn't necessarily "better" than a monochrome image, even though color photos would typically be more popular. Rather, each image is considered based on its merit within its specific style and subject. This helps prevent the scoring system from biasing the model towards specific kinds of generations, and instead keeps it focused on just affecting the quality. I experimented a little with having a well prompted VLM rate the images, and found that the machine ratings matched my own ratings 83% of the time. That's probably good enough that machine ratings could be used to build this dataset in the future, or at least provide significant augmentation to it. For this iteration, I settled on doing "human in the loop" ratings, where the machine rating, as well as an explanation from the VLM about why it rated the images the way it did, was provided to me as a reference and I provided the final rating. I found the biggest failing of the VLMs was in judging compression artifacts and overall "sharpness" of the images.

This head-to-head dataset was then used to train a model to predict the "better" image in each pair. I used the CLIP-B/32 embeddings from earlier in the pipeline, and trained a small classifier head on top. This works well to train a model on such a small amount of data. The dataset is augmented slightly by adding corrupted pairs of images. Images are corrupted randomly using compression or blur, and a rating is added to the dataset between the original image and the corrupted image, with the corrupted image always losing. This helps the model learn to detect compression artifacts and other basic quality issues. After training, this Classifier model reaches an accuracy of 90% on the validation set.

Now for the second phase. An arena of 8,192 random images is pulled from the larger corpus. Using the trained Classifier model, pairs of images compete head-to-head in the "arena" and an ELO ranking is established. There are 8,192 "rounds" in this "competition", with each round comparing all 8,192 images against random competitors.
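For readers who want to see the arena mechanics concretely, here is a minimal sketch of such a loop. It is not the bigASP code: the K-factor of 32, the starting rating of 1000, and the `prefers_first` callback (a stand-in for the trained Classifier model) are all assumptions made for illustration.

```python
import random


def expected_score(r_a, r_b):
    """Standard Elo expectation that A beats B."""
    return 1.0 / (1.0 + 10 ** ((r_b - r_a) / 400.0))


def run_arena(n_images, n_rounds, prefers_first, k=32.0, seed=0):
    """Every round, each image plays one randomly chosen opponent.

    prefers_first(a, b) -> True if image a should win (hypothetical stand-in
    for the pairwise Classifier model). k and the 1000.0 start are assumptions.
    """
    rng = random.Random(seed)
    ratings = [1000.0] * n_images
    for _ in range(n_rounds):
        for a in range(n_images):
            b = rng.randrange(n_images - 1)
            if b >= a:
                b += 1  # never pair an image with itself
            exp_a = expected_score(ratings[a], ratings[b])
            score_a = 1.0 if prefers_first(a, b) else 0.0
            delta = k * (score_a - exp_a)
            ratings[a] += delta  # zero-sum update
            ratings[b] -= delta
    return ratings


# Toy run: pretend the image index is its true quality.
ratings = run_arena(20, 50, prefers_first=lambda a, b: a > b)
```

In this toy run the always-winning image ends above 1000 and the always-losing one below it. At the scale described above the same loop would make 8,192 × 8,192 classifier calls, so the comparisons would normally be batched.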

The ELO ratings are then binned into 10 bins, establishing the 0-9 quality rating of each image in this arena. A second model is trained using these established ratings, much like before: it uses the CLIP-B/32 embeddings with a classifier head trained on top. After training, this model achieves an accuracy of 54% on the validation set. While this might seem quite low, its task is significantly harder than that of the Classifier model from the first stage, since it has to predict which of 10 bins an image belongs to. Ranking an image as "8" when it is actually a "7" is counted as a failure, even though it is quite close. I should probably have a better accuracy metric here...
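A rough sketch of the binning step, plus one possible "better accuracy metric", is below. Two things here are assumptions rather than facts from this write-up: the bins are equal-width slices between the lowest and highest rating, and `accuracy_within` (which counts near misses as hits) is only a suggested metric.

```python
def bin_ratings(ratings, n_bins=10):
    """Map ratings to integer bins 0..n_bins-1 using equal-width bins (assumed)."""
    lo, hi = min(ratings), max(ratings)
    width = (hi - lo) / n_bins or 1.0  # avoid dividing by zero if all equal
    return [min(int((r - lo) / width), n_bins - 1) for r in ratings]


def accuracy_within(predicted, target, tolerance=0):
    """Fraction of predictions at most `tolerance` bins away from the target."""
    hits = sum(abs(p - t) <= tolerance for p, t in zip(predicted, target))
    return hits / len(target)


bins = bin_ratings([0.0, 5.0, 10.0])
exact = accuracy_within([8, 3, 0], [7, 3, 5], tolerance=0)
near = accuracy_within([8, 3, 0], [7, 3, 5], tolerance=1)
```

With `tolerance=1`, predicting "8" for a true "7" counts as correct, which matches the intuition that such a miss is quite close.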

This final "Ranking" model can now be used to rate the larger dataset. I do a small set of images and visualize all the rankings to ensure the model is working as expected. 10 images in each rank, organized into a table with one rank per row. This lets me visually verify that there is an overall "gradient" from rank 0 to rank 9, and that the model is being agnostic in its rankings.

So, why all this hubbub for just a quality model? Why not just collect a dataset of humans rating images 1-10 and train a model directly off that? Why use ELO?

First, head-to-head ratings are far easier to judge for humans. Just imagine how difficult it would be to assess an image, completely on its own, and assign one of ten buckets to put it in. It's a very difficult task, and humans are very bad at it empirically. So it makes more sense for our source dataset of ratings to be head-to-head, and we need to figure out a way to train a model that can output a 0-9 rating from that.

In an ideal world, I would have the ELO arena be based on all human ratings, i.e. grab 8k images, put them into an arena, and compare them in 8k rounds. But that's over 64 million comparisons (8,192 × 8,192 is about 67 million), which just isn't feasible to rate by hand. Hence the use of a two stage system where we train and use a Classifier model to do the arena comparisons for us.

So, why ELO? A simpler approach is to just use the Classifier model to sort 8k images from best to worst, and bin those into 10 bins of about 800 images each. But that introduces an inherent bias. Namely, that each of those bins is equally likely. In reality, it's more likely that the quality of a given image in the dataset follows a gaussian or similar non-uniform distribution. ELO is a more neutral way to stratify the images, so that when we bin them based on their ELO ranking, we're more likely to get a distribution that reflects the true distribution of image quality in the dataset.

With all of that done, and all images rated, score tags can be added to the prompts used during the training of the diffusion model. During training, the data pipeline gets the image's rating. From this it can encode all possible applicable score tags for that image. For example, if the image has a rating of 3, all possible score tags are: score_3, score_1_up, score_2_up, score_3_up. It randomly picks some of these tags to add to the image's prompt. Usually it just picks one, but sometimes two or three, to help mimic how users usually just use one score tag, but sometimes more. These score tags are prepended to the prompt. The underscores are randomly changed to be spaces, to help the model learn that "score 1" and "score_1" are the same thing. Randomly, commas or spaces are used to separate the score tags. Finally, 10% of the time, the score tags are dropped entirely. This keeps the model flexible, so that users don't have to use score tags during inference.
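The tag logic above can be sketched roughly as follows. This is an illustration, not the actual training code: the 80/15/5 split between one, two, and three tags and the 50% chance of swapping underscores for spaces are assumed values, since only "usually one" and "randomly" are stated.

```python
import random


def applicable_score_tags(rating):
    """All score tags that are true for an image with this rating (1..9)."""
    return [f"score_{rating}"] + [f"score_{i}_up" for i in range(1, rating + 1)]


def score_tag_prefix(rating, rng, drop_prob=0.10):
    """Build the randomized score-tag prefix for one training prompt."""
    if rng.random() < drop_prob:
        return ""  # 10% of the time: no score tags at all
    candidates = applicable_score_tags(rating)
    # Usually one tag, sometimes two or three (weights are an assumption).
    n = min(rng.choices([1, 2, 3], weights=[0.80, 0.15, 0.05])[0], len(candidates))
    tags = rng.sample(candidates, n)
    # Randomly swap underscores for spaces (probability is an assumption).
    tags = [t.replace("_", " ") if rng.random() < 0.5 else t for t in tags]
    separator = rng.choice([", ", " "])
    return separator.join(tags)


rng = random.Random(0)
prefix = score_tag_prefix(3, rng)
```

The returned prefix would then be prepended to the tag-based or caption-based prompt for that sample.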

JoyTag

JoyTag is used to generate tags for all the images in the dataset. These tags are saved to the database and used during training. During training, a somewhat complex system is used to randomly select a subset of an image's tags and form them into a prompt. I documented this selection process in the details for Version 1, so definitely check that. But, in short, a random number of tags are randomly picked, joined using random separators, with random underscore dropping, and randomly swapping tags using their known aliases. Importantly, for Version 2, a purely tag based prompt is only used 10% of the time during training. The rest of the time, the image's caption is used.

Captioning

An early version of JoyCaption, Alpha Two, was used to generate captions for bigASP version 2. It is used in random modes to generate a great variety in the kinds of captions the diffusion model will see during training. First, a number of words is picked from a normal distribution centered around 45 words, with a standard deviation of 30 words.

Then, the caption type is picked: 60% of the time it is "Descriptive", 20% of the time it is "Training Prompt", 10% of the time it is "MidJourney", and 10% of the time it is "Descriptive (Informal)". Descriptive captions are straightforward descriptions of the image. They're the most stable mode of JoyCaption Alpha Two, which is why I weighted them so heavily. However they are very formal, and awkward for users to actually write when generating images. MidJourney and Training Prompt style captions mimic what users actually write when generating images. They consist of mixtures of natural language describing what the user wants, tags, sentence fragments, etc. These modes, however, are a bit unstable in Alpha Two, so I had to use them sparingly. I also randomly add "Include whether the image is sfw, suggestive, or nsfw." to JoyCaption's prompt 25% of the time, since JoyCaption currently doesn't include that information as often as I would like.
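As a sketch of how those random modes could be drawn per image: the percentages and the normal distribution come from the text above, but the lower clamp of 10 words (`MIN_WORDS`) and the shape of the returned dictionary are assumptions, since the handling of very small or negative draws isn't described.

```python
import random

CAPTION_TYPES = ["Descriptive", "Training Prompt", "MidJourney", "Descriptive (Informal)"]
CAPTION_WEIGHTS = [60, 20, 10, 10]  # percentages from the text
SFW_INSTRUCTION = "Include whether the image is sfw, suggestive, or nsfw."
MIN_WORDS = 10  # assumption: clamp so the length never goes tiny or negative


def sample_caption_request(rng):
    """Pick the caption settings for one image."""
    word_count = max(MIN_WORDS, round(rng.gauss(45, 30)))
    caption_type = rng.choices(CAPTION_TYPES, weights=CAPTION_WEIGHTS)[0]
    extras = [SFW_INSTRUCTION] if rng.random() < 0.25 else []
    return {
        "caption_type": caption_type,
        "word_count": word_count,
        "extra_instructions": extras,
    }


request = sample_caption_request(random.Random(0))
```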

There are many ways to prompt JoyCaption Alpha Two, so there's lots to play with here, but I wanted to keep things straightforward and play to its current strengths, even though I'm sure I could optimize this quite a bit more.

At this point, the captions could be used directly as the prompts during training (with the score tags prepended). However, there are a couple of specific things about the early version of JoyCaption that I absolutely wanted to fix, since they could hinder bigASP's performance. Training Prompt and MidJourney modes occasionally glitch out into a repetition loop; it uses a lot of vacuous stuff like "this image is a" or "in this image there is"; it doesn't use informal or vulgar words as often as I would like; its watermark detection accuracy isn't great; it sometimes uses ambiguous language; and I need to add the image sources to the captions.

To fix these issues at the scale of 6.7 million images, I trained and then used a sequence of three finetuned Llama 3.1 8B models to make focussed edits to the captions. The first model is multi-purpose: fixing the glitches, swapping in synonyms, removing ambiguity, and removing the fluff like "this image is." The second model fixes up the mentioning of watermarks, based on the OWLv2 detections. If there's a watermark, it ensures that it is always mentioned. If there isn't a watermark, it either removes the mention or changes it to "no watermark." This is absolutely critical to ensure that during inference the diffusion model never generates watermarks unless explicitly asked to. The third model adds the image source to the caption, if it is known. This way, users can prompt for sources.

Training these models is fairly straightforward. The first step is collecting a small set of about 200 examples where I manually edit the captions to fix the issues I mentioned above. To help ensure a great variety in the way the captions get edited, reducing the likelihood that I introduce some bias, I employed zero-shotting with existing LLMs. While all existing LLMs are actually quite bad at making the edits I wanted, with a rather long and carefully crafted prompt I could get some of them to do okay. And importantly, they act as a "third party" editing the captions to help break my biases. I did another human-in-the-loop style of data collection here, with the LLMs making suggestions and me either fixing their mistakes, or just editing it from scratch. Once 200 examples had been collected, I had enough data to do an initial fine-tune of Llama 3.1 8B. Unsloth makes this quite easy, and I just train a small LoRA on top. Once this initial model is trained, I then swap it in instead of the other LLMs from before, and collect more examples using human-in-the-loop while also assessing the performance of the model. Different tasks required different amounts of data, but everything was between about 400 and 800 examples for the final fine-tune.

Settings here were very standard. LoRA rank 16, alpha 16, no dropout, target all the things, no bias, batch size 64, 160 warmup samples, 3200 training samples, 1e-4 learning rate.

I must say, 400 is a very small number of examples, and Llama 3.1 8B fine-tunes beautifully from such a small dataset. I was very impressed.

This process was repeated for each model I needed, each in sequence consuming the edited captions from the previous model. Which brings me to the gargantuan task of actually running these models on 6.7 million captions. Naively using HuggingFace transformers inference, even with torch.compile or unsloth, was going to take 7 days per model on my local machine. Which meant 3 weeks to get through all three models. Luckily, I gave vLLM a try, and, holy moly! vLLM was able to achieve enough throughput to do the whole dataset in 48 hours! And with some optimization to maximize utilization I was able to get it down to 30 hours. Absolutely incredible.

After all of these edit passes, the captions were in their final state for training.

VAE encoding

This step is quite straightforward, just running all of the images through the SDXL VAE and saving the latents to disk. This pre-encode saves VRAM and processing during training, and massively shrinks the dataset size. Each image in the dataset is about 1MB, which means the dataset as a whole is nearly 7TB, making it infeasible for me to do training in the cloud where I can utilize larger machines. But once gzipped, the latents are only about 100KB each, 10% the size, dropping it to 725GB for the whole dataset. Much more manageable. (Note: I tried zstandard to see if it could compress further, but it resulted in worse compression ratios even at higher settings. Need to investigate.)
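A back-of-the-envelope sketch of the latent sizes, with a gzip round trip. The storage dtype isn't stated here, so the 2 bytes per value (fp16) is an assumption. The 4 latent channels and 8x spatial downscale are properties of the SDXL VAE. A real pipeline would serialize tensors directly; the stdlib `struct` packing below is just to keep the example self-contained.

```python
import gzip
import struct

LATENT_CHANNELS = 4  # SDXL VAE latent channels
DOWNSCALE = 8        # SDXL VAE spatial downscale factor


def latent_nbytes(width, height, bytes_per_value=2):
    """Uncompressed latent size in bytes (fp16 storage is an assumption)."""
    return LATENT_CHANNELS * (height // DOWNSCALE) * (width // DOWNSCALE) * bytes_per_value


def pack_latent(values):
    """Serialize a flat list of floats as little-endian fp16, then gzip."""
    raw = struct.pack(f"<{len(values)}e", *values)
    return gzip.compress(raw)


def unpack_latent(blob):
    """Inverse of pack_latent."""
    raw = gzip.decompress(blob)
    return list(struct.unpack(f"<{len(raw) // 2}e", raw))


size_1mp = latent_nbytes(1024, 1024)  # 131072 bytes, i.e. 128 KiB before gzip
restored = unpack_latent(pack_latent([0.5, -1.25, 2.0]))
```

Under the fp16 assumption a 1024x1024 image gives a 128 KiB latent before compression, which is in the same ballpark as the roughly 100KB per gzipped latent quoted above.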

Aspect Ratio Bucketing and more

Just like v1 and many other models, I used aspect ratio bucketing so that different aspect ratios could be fed to the model. This is documented to death, so I won't go into any detail here. The only thing different, and new to version 2, is that I also bucketed based on prompt length.

One issue I noted while training v1 is that the majority of batches had a mismatched number of prompt chunks. For those not familiar, to handle prompts longer than the limit of the text encoder (75 tokens), NovelAI invented a technique which pretty much everyone has implemented into both their training scripts and inference UIs. The prompts longer than 75 tokens get split into "chunks", where each chunk is 75 tokens (or less). These chunks are encoded separately by the text encoder, and then the embeddings all get concatenated together, extending the UNet's cross attention.

In a batch, if one image has only 1 chunk and another has 2 chunks, they have to be padded out to the same length, so the first image gets 1 extra chunk of pure padding appended. This isn't necessarily bad; the UNet just ignores the padding. But the issue I ran into is that at larger mini-batch sizes (16 in my case), the majority of batches end up with different numbers of chunks, by sheer probability, and so almost all batches that the model would see during training were 2 or 3 chunks, and lots of padding. For one thing, this is inefficient, since more chunks require more compute. Second, I'm not sure what effect this might have on the model if it gets used to seeing 2 or 3 chunks during training, but then during inference only gets 1 chunk. Even if there's padding, the model might get numerically used to the number of cross-attention tokens.
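To put a number on "by sheer probability": suppose, purely for illustration, that 70% of prompts fit in a single chunk and that the prompts in a batch are independent. Neither figure is from the actual dataset. The chance that a whole mini-batch of 16 needs only one chunk is then tiny.

```python
def prob_batch_single_chunk(p_single, batch_size):
    """Probability that every prompt in a batch fits in one chunk,
    assuming prompts are drawn independently."""
    return p_single ** batch_size


p = prob_batch_single_chunk(0.70, 16)
print(f"{p:.2%}")  # prints 0.33%
```

So even if most individual prompts are short, nearly every randomly assembled batch of 16 contains at least one long prompt and gets padded up to its chunk count.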

To deal with this, during the aspect ratio bucketing phase, I estimate the number of tokens an image's prompt will have, calculate how many chunks it will be, and then bucket based on that as well. While not 100% accurate (due to randomness of length caused by the prepended score tags and such), it makes the distribution of chunks in the batch much more even.
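A minimal sketch of bucketing on both keys is below. The sample dictionaries, the `est_tokens` field, and the bucket labels are hypothetical. The only facts taken from the text are the 75-token chunk size and the idea of grouping by aspect ratio bucket plus estimated chunk count.

```python
import math
from collections import defaultdict

CHUNK_TOKENS = 75  # usable tokens per text-encoder chunk


def num_chunks(n_tokens):
    """Number of 75-token chunks a prompt needs (at least one)."""
    return max(1, math.ceil(n_tokens / CHUNK_TOKENS))


def make_batches(samples, batch_size):
    """Group samples by (aspect bucket, estimated chunk count), then batch."""
    buckets = defaultdict(list)
    for sample in samples:
        key = (sample["bucket"], num_chunks(sample["est_tokens"]))
        buckets[key].append(sample)
    batches = []
    for items in buckets.values():
        for i in range(0, len(items), batch_size):
            batches.append(items[i:i + batch_size])
    return batches


samples = [{"bucket": "1024x1024", "est_tokens": t} for t in (10, 70, 80, 140, 160)]
samples.append({"bucket": "832x1216", "est_tokens": 20})
batches = make_batches(samples, batch_size=2)
```

Every batch produced this way shares one resolution and one estimated chunk count, so padding only appears when the estimate turns out to be off.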

UCG

As always, the prompt is dropped completely by setting it to an empty string some small percentage of the time. 5% in the case of version 2. In contrast to version 1, I removed the code that also randomly set the text embeddings to zero. This random setting of the embeddings to zero stems from Stability's reference training code, but it never made much sense to me since almost no UIs set the conditions like the text conditioning to zero. So I disabled that code completely and just do the traditional setting of the prompt to an empty string 5% of the time.

Training

Training commenced almost identically to version 1. min-snr loss, fp32 model with AMP, AdamW, 2048 batch size, no EMA, no offset noise, 1e-4 learning rate, 0.1 weight decay, cosine annealing with linear warmup for 100,000 training samples, text encoder 1 training enabled, text encoder 2 kept frozen, min_snr_gamma=5, GradScaler, 0.9 adam beta1, 0.999 adam beta2, 1e-8 adam eps. Everything initialized from SDXL 1.0.
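Two of those settings are easy to show in a few lines. This sketch assumes the standard Min-SNR-gamma weighting for epsilon-prediction, and a cosine schedule that decays to zero at the 40M-sample mark. The final learning rate and the exact shape of the schedule are assumptions, and the function names are made up for illustration.

```python
import math


def min_snr_weight(snr, gamma=5.0):
    """Min-SNR-gamma loss weight for epsilon-prediction: min(SNR, gamma) / SNR."""
    return min(snr, gamma) / snr


def lr_at(samples_seen, base_lr=1e-4, warmup=100_000, total=40_000_000):
    """Linear warmup over `warmup` samples, then cosine annealing to zero
    at `total` samples (decaying all the way to zero is an assumption)."""
    if samples_seen < warmup:
        return base_lr * samples_seen / warmup
    progress = min((samples_seen - warmup) / (total - warmup), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


optimizer_steps = 40_000_000 // 2048  # about 19.5k steps at batch size 2048
```

Low-noise timesteps (high SNR) get down-weighted, while timesteps with SNR at or below gamma keep a weight of 1.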

Compared to version 1, I upped the training samples from 30M to 40M. I felt like 30M left the model a little undertrained.

A validation dataset of 2048 images is sliced off the dataset and used to calculate a validation loss throughout training. A stable training loss is also measured at the same time as the validation loss. Stable training loss is similar to validation, except the slice of 2048 images it uses are not excluded from training. One issue with training diffusion models is that their training loss is extremely noisy, so it can be hard to track how well the model is learning the training set. Stable training loss helps because its images are part of the training set, so it's measuring how the model is learning the training set, but they are fixed so the loss is much more stable. By monitoring both the stable training loss and validation loss I can get a good idea of whether A) the model is learning, and B) if the model is overfitting.
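The split can be sketched like this. The function name and seed are hypothetical. The one structural point taken from the text is that the validation slice is excluded from training while the stable slice is not.

```python
import random


def split_for_monitoring(image_ids, n_val=2048, n_stable=2048, seed=0):
    """Return (train_ids, val_ids, stable_ids).

    val_ids are held out entirely; stable_ids are a fixed subset of
    train_ids that still gets trained on.
    """
    rng = random.Random(seed)
    shuffled = list(image_ids)
    rng.shuffle(shuffled)
    val_ids = shuffled[:n_val]         # never seen during training
    train_ids = shuffled[n_val:]       # everything else
    stable_ids = train_ids[:n_stable]  # fixed slice, remains in training
    return train_ids, val_ids, stable_ids


train_ids, val_ids, stable_ids = split_for_monitoring(range(10_000))
```

One further assumption worth stating: to keep both curves smooth, the evaluation pass would likely also fix the noise and timestep sampled for each image (for example with a per-image seed), so the only thing changing between evaluations is the model.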

Training was done on an 8xH100 SXM5 machine rented in the cloud. Compared to version 1, the iteration speed was a little faster this time, likely due to optimizations in PyTorch and the drivers in the intervening months. Throughput was about 80 images/s, so the entire 40M-sample training run took just under 6 days.

Training commenced by spinning up the server, rsync-ing the latents and metadata over, as well as all the training scripts, opening tmux, and starting the run. Everything gets logged to Wandb to help me track the stats, and checkpoints are saved every 500,000 samples. Every so often I rsync the checkpoints to my local machine, as well as upload them to HuggingFace as a backup.

On my local machine I use the checkpoints to generate samples during training. While the validation loss going down is nice to see, actual samples from the model running inference are critical to measuring the tangible performance of the model. I have a set of prompts and fixed seeds that get run through each checkpoint, and everything gets compiled into a table and saved to an HTML file for me to view. That way I can easily compare each prompt as it progresses through training.
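A comparison page like that needs very little code. The sketch below only builds the HTML. The `image_path` callback and the file layout it implies are hypothetical, and generating the images themselves is out of scope.

```python
from html import escape


def build_sample_table(prompts, checkpoints, image_path):
    """One row per prompt, one column per checkpoint.

    image_path(checkpoint, prompt_index) -> str is a hypothetical helper that
    returns where the sample for that checkpoint and prompt was saved.
    """
    header = "".join(f"<th>{escape(c)}</th>" for c in checkpoints)
    rows = []
    for i, prompt in enumerate(prompts):
        cells = "".join(
            f'<td><img src="{escape(image_path(c, i))}" width="256"></td>'
            for c in checkpoints
        )
        rows.append(f"<tr><th>{escape(prompt)}</th>{cells}</tr>")
    return f"<table><tr><th>prompt</th>{header}</tr>{''.join(rows)}</table>"


page = build_sample_table(
    ["a <cat> on a beach"],
    ["ckpt_1", "ckpt_2"],
    lambda checkpoint, index: f"{checkpoint}/{index}.png",
)
```

Because the seeds are fixed, reading across a row shows how one prompt evolves from checkpoint to checkpoint.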

Post Mortem (What worked)

The big difference in version 2 is the introduction of captions, instead of just tags. This was unequivocally a success, bringing a whole range of new promptable concepts to the model. It also makes the model significantly easier for users.

I'm overall happy with how JoyCaption Alpha Two performed here. As JoyCaption progresses toward its 1.0 release I plan to get it to a point where it can be used directly in the training pipeline, without the need for all these Llama 3.1 8B models to fix up the captions.

bigASP v2 adheres fairly well to prompts. Not at FLUX or DALLE 3 levels by any means, but for just a single developer working on this, I'm happy with the results. As JoyCaption's accuracy improves, I expect prompt adherence to improve as well. And of course future versions of bigASP are likely to use more advanced models like Flux as the base.

Increasing the training length to 40M I think was a good move. Based on the sample images generated during training, the model did a lot of "tightening up" in the later part of training, if that makes sense. I know that models like Pony XL were trained for a multiple or more of my training size. But this run alone cost about $3,600, so ... it's tough for me to do much more.

The quality model seems improved, based on what I'm seeing. The range of "good" quality is much higher now, with score_5 being kind of the cut-off for decent quality. Whereas v1 cut off around 7. To me, that's a good thing, because it expands the range of bigASP's outputs.

Some users don't like using score tags, so dropping them 10% of the time was a good move. Users also report that they can get "better" gens without score tags. That makes sense, because the score tags can limit the model's creativity. But of course not specifying a score tag leads to a much larger range of qualities in the gens, so it's a trade off. I'm glad users now have that choice.

For version 2 I added 2M SFW images to the dataset. The goal was to expand the range of concepts bigASP knows, since NSFW images are often quite limited in what they contain. For example, version 1 had no idea how to draw an ice cream cone. Adding in the SFW data worked out great. Not only is bigASP a good photoreal SFW model now (I've frequently gen'd nature photographs that are extremely hard to discern as AI), but the NSFW side has benefitted greatly as well. Most importantly, NSFW gens with boring backgrounds and flat lighting are a thing of the past!

I also added a lot of male focussed images to the dataset. I've always wanted bigASP to be a model that can generate for all users, and excluding 50% of the population from the training data is just silly. While version 1 definitely had male focussed data, it was not nearly as representative as it should have been. Version 2's data is much better in this regard, and it shows. Male gens are closer than ever to parity with female focussed gens. There's more work yet to do here, but it's getting better.

Post Mortem (What didn't work)

The finetuned Llama models for fixing up the captions would themselves very occasionally fail. It's quite rare, maybe 1 in 1,000 captions, but of course it's not ideal. And since they're chained, that increases the error rate. The fix is, of course, to have JoyCaption itself get better at generating the captions I want. So I'll have to wait until I finish work there :p

I think the SFW dataset can be expanded further. It's doing great, but could use more.

I experimented with adding things outside the "photoreal" domain in version 2. One thing I want out of bigASP is the ability to create more stylistic or abstract images. My focus is not necessarily on drawings/anime/etc. There are better models for that. But being able to go more surreal or artsy with the photos would be nice. To that end I injected a small amount of classical art into the dataset, as well as images that look like movie stills. However, neither of these seems to have been learned well in my testing. Version 2 can operate outside of the photoreal domain now, but I want to improve it more here and get it learning more about art and movies, where it can gain lots of styles from.

Generating the captions for the images was a huge bottleneck. I hadn't discovered the insane speed of vLLM at the time, so it took forever to run JoyCaption over all the images. It's possible that I can get JoyCaption working with vLLM (multi-modal models are always tricky), which would likely speed this up considerably.

Post Mortem (What really didn't work)

I'll preface this by saying I'm very happy with version 2. I think it's a huge improvement over version 1, and a great expansion of its capabilities. Its ability to generate fine grained details and realism is even better. As mentioned, I've made some nature photographs that are nearly indistinguishable from real photos. That's crazy for SDXL. Hell, version 2 can even generate text sometimes! Another difficult feat for SDXL.

BUT, and this is the painful part. Version 2 is still ... temperamental at times. We all know how inconsistent SDXL can be. But it feels like bigASP v2 generates mangled corpses far too often. An out of place limb here and there, bad hands, weird faces are all fine, but I'm talking about flesh soup gens. And what really bothers me is that I could maybe dismiss it as SDXL being SDXL. It's an incredible technology, but has its failings. But Pony XL doesn't really have this issue. Not all gens from Pony XL are "great", but body horror is at a much more normal level of occurrence there. So there's no reason bigASP shouldn't be able to get basic anatomy right more often.

Frankly, I'm unsure as to why this occurs. One theory is that SDXL is being pushed to its limit. Most prompts involving close-ups work great. And those, intuitively, are "simpler" images. Prompts that zoom out and require more from the image? That's when bigASP drives the struggle bus. 2D art from Pony XL is maybe "simpler" in comparison, so it has fewer issues, whereas bigASP is asking a lot of SDXL's limited compute capacity. Then again Pony XL has an order of magnitude more concepts and styles to contend with compared to photos, so shrug.

Another theory is that bigASP has almost no bad data in its dataset. That's in contrast to base SDXL. While that's not an issue for LoRAs, which are only slightly modifying the base model, bigASP is doing heavy modification. That is both its strength and weakness. So during inference, it's possible that bigASP has forgotten what "bad" gens are and thus has difficulty moving away from them using CFG. This would explain why applying Perturbed Attention Guidance to bigASP helps so much. It's a way of artificially generating bad data for the model to move its predictions away from.

Yet another theory is that base SDXL is possibly borked. Nature photography works great way more often than images that include humans. If humans were heavily censored from base SDXL, which isn't unlikely given what we saw from SD 3, it might be crippling SDXL's native ability to generate photorealistic humans in a way that's difficult for bigASP to fix in a fine-tune. Perhaps more training is needed, like on the level of Pony XL? Ugh...

And the final (most probable) theory ... I fucked something up. I've combed the code back and forth and haven't found anything yet. But it's possible there's a subtle issue somewhere. Maybe min-snr loss is problematic and I should have trained with normal loss? I dunno.

While many users are able to deal with this failing of version 2 (with much better success than myself!), and when version 2 hits a good gen it hits, I think it creates a lot of friction for new users of the model. Users should be focussed on how to create the best image for their use case, not on how to avoid the model generating a flesh soup.

Graphs

Wandb run:

https://api.wandb.ai/links/hungerstrike/ula40f97

Validation loss:

Stable loss:

Source code

Source code for the training scripts, Python notebooks, data processing, etc were all provided for version 1: https://github.com/fpgaminer/bigasp-training

I'll update the repo soon with version 2's code. As always, this code is provided for reference only; I don't maintain it as something that's meant to be used by others. But maybe it's helpful for people to see all the mucking about I had to do.

Final Thoughts

I hope all of this is useful to others. I am by no means an expert in any of this; just a hobbyist trying to create cool stuff. But people seemed to like the last time I "dumped" all my experiences, so here it is.
