Hello all, my name is Wolfgang Black - I'm on the ML/AI team here at Civitai. Welcome to the second installment of Machine Learning Projects: Classifying Content into Ratings. In this series, I'll be discussing how we've tackled the problem of classifying images and other media into movie-like ratings for our users to browse.
Although this article details the work we've done for media classification, I wouldn't call this a solved problem! I've talked with other machine learning (ML) scientists at various social media companies who have expressed frustration with existing ML solutions. The truth of the matter is, human moderators are needed across all levels and forms of media. Even at Civitai, we're still demoing this work in the backend as we try to tune our models to achieve better performance for labeling media. This series of articles is meant to share our efforts at Civitai and their results with our community, especially since the deployment of these solutions may very well affect the site! After the last article, I had some great conversations with our readers and gained some useful insights from them, so I invite that again below. Have ideas on how to tackle this problem? Want to try it on your own? Check out the dataset and engage in the comments, let's discuss!
This problem really lends itself to a multimodal solution. The majority of images shared on Civitai were either generated with our onsite generator or uploaded to a user's profile after being generated locally. In most cases, users share the generation information - specifically the prompt used in generation. We also use ML models, like WDTagger, to add tags to our media. Because of this, we have two types of text data and the image data itself. To tackle this problem, we tried two main solutions: a mixture of 'experts' and a multimodal model. These models are composite solutions which take in the outputs or the hidden layers of other ML models. To create those solutions, we needed to do some traditional deep learning (DL) using our single modalities. In this article, I'll cover our efforts in Computer Vision (CV) - if you're interested in our efforts in Natural Language Processing (NLP) and our mixture models, hang tight for article 3.
This article will be long. I want to cover the models we explored, their architectures, and how these models performed. I'll try to include some images and tables, and I'll also link references when important. This isn't meant to be super instructive or to help people unfamiliar with the models master the concepts. It's more a project roadmap of how I tackled this problem. I'm always happy to discuss in the comments if there are questions. Before we get into the models though, let's remind ourselves what we're talking about at a high level.
What Are We Trying to Solve?
At Civitai we receive hundreds of thousands of images from either our generation pipeline, where users experiment with different checkpoints and LoRAs, or ingestion from users sharing their locally generated images. We also receive a ton of data from users in training datasets, which users can use to create their own LoRAs onsite. When users want to make this data public and share with the community, we at Civitai have to make sure we can serve those images up to the appropriate audience. One way we do this is with image ratings that mimic movie ratings. Users are able to control the nature of the content they’ll see in their image by toggling on or off specific ratings.
To keep our users’ browsing experience pleasant, we want to make sure our ratings are accurate and representative of the content. Currently we use a system of tags, converting the tags to some numeric value, and then determining a max allowable threshold per rating. If tags push the numeric value over a specific rating threshold, the rating graduates to a more restrictive level. However, we’re interested in seeing whether we can run a smaller or simpler architecture than a tagger to determine the rating. As such, I’ve tried a few different architectures, modalities, and strategies to build an end-to-end pipeline to correctly classify images with the movie ratings.
It's important to remember, we can always try Vision Language Models (VLMs) or some advanced API - but what are the costs associated with hitting ChatGPT with the 1.5 million images generated daily? What are the computational requirements and latency inherent in the use of larger open source VLMs? As amazing as these large model systems are, can we provide support to our moderator teams with smaller, more cost-effective models?
To approach this problem, we built mixture models which contain an odd number of single-modality models and implement a voter layer. The single-modality models each individually classify either the image or the text. The text could be just the prompt or the prompt with the ML tags. The models would classify their assigned modality and pass their prediction to the voter, which would select the most voted upon index. If there was a tie, the more conservative option, that is the higher NSFW level, would be selected. We also experimented with multimodal models, which would take the pre-logit layers from different ML models and concatenate the outputs. These outputs would become the inputs of a simple Multilayer Perceptron (MLP). This MLP would then output logits, and we'd select the highest logit to determine the classification of the media. Readers should know that the multimodal model requires more training, as each individual model was trained for our task and then the MLP was also further trained.
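To make the voter concrete, here is a minimal sketch in plain Python. It is an illustration only, not production code: the function name `vote` and the index order of `RATINGS` (PG lowest, XXX highest) are assumptions made for this example.

```python
from collections import Counter

# Assumed rating order, from least to most restrictive.
RATINGS = ["PG", "PG13", "R", "X", "XXX"]


def vote(predictions):
    """Return the majority rating index; ties go to the higher (more NSFW) index."""
    counts = Counter(predictions)
    top = max(counts.values())
    # All indices that share the highest vote count.
    tied = [idx for idx, n in counts.items() if n == top]
    # Conservative tie-break: choose the most restrictive rating.
    return max(tied)


# Three hypothetical single-modality predictions (image, prompt, prompt + tags).
print(RATINGS[vote([2, 2, 1])])  # R: the majority wins
print(RATINGS[vote([0, 3, 1])])  # X: a three-way tie resolves to the highest level
```

With an odd number of voters a two-way tie cannot happen, but a full disagreement (every model predicting a different rating) still can, which is where the conservative tie-break matters.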
This article focuses on the single modality CV models, their training, and performance. The next article in the series will cover the mixtures and the multimodal models.
The Computer Vision Models
Our first modality is the image itself. While there are a number of advanced DL architectures out there, I tend to start most projects trying to create a simple explainable base model. In this case, I chose to explore utilizing ResNets. In addition to ResNets, a member of our community had some experience with You-Only-Look-Once (YOLO) models, and so they suggested trying out a few different types of YOLO. As the dataset grew, I also wanted to try to find the line between using traditional DL models and trying out transformers. Research has shown that pretrained transformers can generalize well when finetuned on relatively little task-specific data, and so I also explored finetuning a Vision Transformer (ViT). In this section, I'll cover a quick background of each model, their main architecture, and then how well each model performed on the most recent dataset.
Residual (Neural) Networks
Residual (Neural) Networks, or ResNets for short, are notable for their ability to form very deep networks without the problem of vanishing or exploding gradients. ResNets avoid vanishing gradients because they learn the residual function instead of the true underlying mapping of the inputs to the output. They utilize skip connections to bypass layers, adding the inputs of layers to the output of deeper layers. This means that even if the gradient through the bypassed layers would vanish or change detrimentally, the gradient can still flow through the skip connection and be preserved.
Architecture
A residual block can be seen below. In this figure, we see the input to the residual block being processed through some layers, such as convolutional layers, dropout, etc. We can have activations throughout the residual block; however, at the end of the residual block, before the final activation, we add the inputs of the residual block to the outputs of the final layer. This sum is then passed through a final activation to the next section of the neural network. Typically, ReLU activations are used.
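As a rough sketch of the computation a residual block performs, the snippet below implements relu(F(x) + x) in plain Python. It is a simplification for illustration: real ResNets use convolutional layers and batch normalization, whereas this example assumes two small dense layers, and the helper names (`relu`, `linear`, `residual_block`) are made up for this sketch.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(v, weights, bias):
    """Dense layer: one output per row of `weights`."""
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]


def residual_block(x, w1, b1, w2, b2):
    """Compute relu(F(x) + x), where F is linear -> relu -> linear."""
    hidden = relu(linear(x, w1, b1))
    f_x = linear(hidden, w2, b2)
    # Skip connection: add the block input to the output of its last layer,
    # then apply the final activation.
    return relu([f + xi for f, xi in zip(f_x, x)])


x = [1.0, 2.0]
zeros = [[0.0, 0.0], [0.0, 0.0]]
# With all-zero weights F(x) = 0, so the block reduces to the identity on x.
print(residual_block(x, zeros, [0.0, 0.0], zeros, [0.0, 0.0]))  # [1.0, 2.0]
```

The zero-weight case shows why the skip connection helps: even when the inner layers contribute nothing, the input still passes through unchanged.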
When ResNets were first reported on, back in 2015, they showed state-of-the-art (SOTA) performance on image classification tasks, beating out the standard methods at the time. Their real strength was utilizing skip connections to prevent gradient issues while building deeper networks capable of learning more complex features for image tasks. Today, they're lightweight, easy-to-train networks that have easily customizable implementations in the TensorFlow, PyTorch, and Transformers libraries.
For our first experiments, we trained a resnet18 and a resnet50 model, using the data with an 80/20 train/test split. We also experimented with k-fold cross-validation, trying various slices of the data to allow for better performance on a validation dataset. Readers should note that for all trainings we balanced the classes by down-sampling to the 'minority' class - R, which had around 5k examples. The results of these models can be seen in the next section. These results served as a north star for the project - they gave us confidence that this problem COULD be simplified with machine learning, but that we needed more data and potentially more powerful architectures.
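For readers who want to try the balancing step themselves, here is a small sketch of down-sampling every class to the size of the minority class, followed by an 80/20 split. The function names, the fixed seed, and the toy file names are all assumptions for illustration, not the actual training pipeline.

```python
import random
from collections import defaultdict


def downsample_to_minority(samples, seed=0):
    """Randomly drop examples so every class matches the smallest class size."""
    by_class = defaultdict(list)
    for item, label in samples:
        by_class[label].append(item)
    target = min(len(items) for items in by_class.values())
    rng = random.Random(seed)
    balanced = []
    for label, items in by_class.items():
        for item in rng.sample(items, target):
            balanced.append((item, label))
    rng.shuffle(balanced)
    return balanced


def train_test_split(samples, train_fraction=0.8, seed=0):
    """Shuffle a copy of the samples and cut it into train and test parts."""
    shuffled = samples[:]
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


# Toy data: file names and class sizes are made up for illustration.
data = [(f"pg_{i}.png", "PG") for i in range(50)] + [(f"r_{i}.png", "R") for i in range(10)]
balanced = downsample_to_minority(data)
train, test = train_test_split(balanced)
print(len(balanced), len(train), len(test))  # 20 16 4
```

Down-sampling throws data away, so it trades dataset size for balance; class weighting or over-sampling are common alternatives.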
You Only Look Once - A community recommendation
During this project, we partnered with community member Pitpe11 to test another computer vision model—YOLO. YOLO is a popular real-time object detection model known for its speed and accuracy. Unlike traditional object detection models that use a sliding window, YOLO treats object detection as a regression problem. It uses bounding boxes within images to identify regions containing specific class objects and learns the features associated with each object within a bounding box regardless of size and location within the images.
To achieve this, YOLO divides the image into a grid. Each grid cell predicts a fixed number of bounding boxes, their confidence scores, and class probabilities. The confidence score reflects the likelihood that a bounding box contains an object and how accurate the bounding box coordinates are. The predictions with high confidence scores are kept, and overlapping boxes among them are combined to determine the final object regions. YOLO identifies regions within images that may contain different objects/classes and centers the boxes on specific objects. Without getting into the mathematics, the important idea here is that we use the ratio of the area of overlap to the area of union, known as intersection over union (IoU), to refine the bounding boxes. An example of this is shown in the image below:
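The overlap-over-union idea can be sketched in a few lines of Python. The snippet below computes IoU for two boxes and uses it in a greedy non-maximum suppression (NMS) pass, which is the standard way overlapping detections are merged. It is a generic illustration, not the Ultralytics implementation; the `(x1, y1, x2, y2)` box format and the 0.5 threshold are assumptions.

```python
def iou(box_a, box_b):
    """Intersection over union for boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    intersection = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - intersection
    return intersection / union if union > 0 else 0.0


def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: keep the highest-scoring box, drop boxes that overlap it too much."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    kept = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in kept):
            kept.append(i)
    return kept


boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
# The second box overlaps the first with IoU 0.81, so it is suppressed.
print(non_max_suppression(boxes, scores))  # [0, 2]
```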
The library Ultralytics provides templates on how to train and deploy a YOLO model. For this experiment, Pitpe11 provided us with a YOLO model trained for detection but configured to detect only one 'object' per class. This approach for classification is interesting because it sweeps over the entire grid and examines each grid for context that could be considered in a specific category. Unfortunately, due to the integrated nature of Ultralytics' implementation, extracting the YOLO detection model and using it outside of the Ultralytics library can be challenging. Later in the results section, we'll refer to this model as Yolo-Det.
The community member also suggested looking at a YOLO classifier model, which is similar but falls under the classification category. This model was easier to train as well as to extract and use alongside other models. This model will be denoted as Yolo-CLS below.
Interestingly, the detection model performed better than the classifier. The main difference between using detection vs classification is that in detection there are custom detection layers which explore the grid and assign the probabilities of each class to each grid cell - this may serve to create more unique feature mappings and help to capture odd edge cases. In classification these detection layers are replaced with a fully connected layer, more standard in traditional classification problems. If the grid is what truly made the difference, perhaps we can take advantage of some new algorithms that also divide images into different regions for classification.
A modeling aside: Transformers
In 2017, Google published a groundbreaking machine learning architecture in the paper "Attention is All You Need". This architecture, known as the Transformer, introduced a method called Attention, which allows the model to consider specific parts of the input as either important or ignorable. This was particularly beneficial for natural language processing (NLP) problems, which often involve inputs of varying lengths.
Standard text embedding models for generative AI may only consider the first 75 tokens of the prompt or apply some transformation on the embeddings to handle the full prompt. While modern Large Language Models (LLMs) now support large context windows, early NLP models struggled with varying context window sizes. When padding was used, these models could not distinguish between actual data and padding, treating the padded embeddings as meaningful input.
The introduction of Attention addressed this issue by enabling researchers to indicate that padding tokens should be ignored. By applying zeros in the attention layer for padding tokens, the model effectively disregards them, simplifying the mathematical operations in the embedded space as the embeddings pass through the model. This innovation significantly improved the handling of varying input lengths in NLP models.
Additionally, the attention weights can take any value between 0 and 1, indicating various levels of relation between tokens. In this way, the attention layer allows the model to consider various tokens across the entire sequence, regardless of position or directionality. This allows the model to learn deeper connections between tokens within sequences, leading to richer feature representations and better performance.
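Here is a minimal sketch of how a padding mask interacts with attention weights, written in plain Python for a single query vector. In practice the masked positions are usually set to negative infinity before the softmax so that their weight comes out as exactly zero. The function names and toy vectors are assumptions for illustration, and the sketch assumes at least one unmasked token.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over `scores`, forcing positions with mask == 0 to weight 0."""
    # Masked positions get -inf, so exp(-inf) == 0 after the softmax.
    masked = [s if m == 1 else float("-inf") for s, m in zip(scores, mask)]
    peak = max(masked)  # assumes at least one unmasked position
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values, mask):
    """Scaled dot-product attention for a single query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    weights = masked_softmax(scores, mask)
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, output


# Three tokens; the last one is padding and is masked out.
keys = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
values = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]
weights, output = attend([1.0, 0.0], keys, values, mask=[1, 1, 0])
print(weights)  # the third weight is 0.0, the first two sum to 1
```

Even though the padding token has the largest raw score, it contributes nothing to the output, while the two real tokens receive weights strictly between 0 and 1.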
Vision Transformers
What does attention have to do with vision? Well, we already revealed that the architecture that utilizes attention is called a Transformer. In 2020, researchers applied this concept to CV, which led to the development of the first Vision Transformer (ViT) architecture for image classification. The key idea was to divide image inputs into a series of patches (or a grid) and then linearly project the flattened patches into an input sequence. This input sequence was then processed through an embedding layer whose output mirrored the dimensionality of NLP transformers, and so researchers were able to use the basic transformer architecture with an MLP classification head for image classification. By using a transformer with attention for vision, researchers were able to achieve SOTA results with much smaller networks, leading to competitive and computationally efficient models for image classification.
Architecture
To better understand how ViT works, I've included an illustration below. Essentially, the image is divided into a grid. Each grid cell is converted into a numpy array. After flattening each of these arrays into a 1D vector of length X, we stack them into a single sequence. For example, if we had a grid of 32 cells, we'd stack them into an array of shape 1x32xX. This array then gets passed through an embedding layer to reach a dimensionality suitable for input into the transformer layers. The learned embeddings then enter the transformer encoder, which utilizes both skip connections (similar to ResNet) and attention mechanisms.
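The patching step can be sketched without any libraries. The example below splits a toy single-channel image into non-overlapping square patches, flattens each one, and applies a linear projection per patch. The function names and the 4x4 toy image are assumptions for illustration; a real ViT also adds a class token and position embeddings, which are left out here.

```python
def patchify(image, patch_size):
    """Split a 2D image (list of rows) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


def embed(patches, projection):
    """Linearly project each flattened patch; `projection` has one row per output dim."""
    return [[sum(w * p for w, p in zip(row, patch)) for row in projection] for patch in patches]


# A 4x4 single-channel toy image split into 2x2 patches -> 4 patches of length 4.
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = patchify(image, 2)
print(len(patches), len(patches[0]))  # 4 4
print(patches[0])  # [0, 1, 4, 5]
```

Each flattened patch plays the role a token embedding plays in NLP, so the transformer encoder sees the image as a sequence of patches.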
So, how does attention work in vision? These kinds of transformers also behave as 'encoders'. This is because we use attention to encode information about each patch relative to the others. The attention mechanism helps the model understand how each patch is related to every other patch. This calculation is done over each patch multiple times, creating an encoded output that contains information about the features of each patch and their relevance to other patches. This output serves as the features of the image and is fed into the MLP classification head.
In this way, ViTs can better learn low-level shapes and higher-level patterns that make up various concepts within the images. It's important to note that the MLP is a basic type of neural network consisting mainly of fully connected layers, but it may include dropout or other minor layers to improve performance. This means that the ViT primarily serves as a feature processor/encoder, while the main classification scores come from the MLP classification head.
Results
Above we covered the various architectures we utilized for CV in this work. We also experimented with different depths for the ResNets, finetuning on top of various pre-training weights, and different training techniques like early stopping and focusing on specific metrics over others. We also explored different slices of the data through shuffling, different data augmentations, and techniques like cross-validation and down-sampling various classes. In this section, we present the results of these efforts by reporting metrics like class accuracy and f1 score. We'll also look at Down Class Misclassification (DCM), a significant metric for this project as it highlights the misclassification of images with significant differences in content rating, like an X rated image being scored as PG13.
To get these performance metrics, we evaluated the models on a test set. Unlike the training dataset, this test set is some 12k images - the bulk of which have not been publicly shared. These images have all been manually reviewed by either the head of our moderation team, our steadfast CEO, or myself, or are images that did not make it into any of the training sets.
The metrics we'll report are derived from four parameters: True Positive (TP), False Positive (FP), True Negative (TN) and False Negative (FN). These concepts make up the base confusion matrix. To explain these concepts, let's consider a binary classifier which predicts class 0 or class 1 on an example whose true class is 0. If the model predicts class 0 correctly, then we can say the model has made a TP prediction for class 0 - and a TN for class 1. If it predicts class 1 instead, it's a FN for class 0 and a FP for class 1. Thus every prediction falls into one of these categories across all classes. However, rather than considering each parameter for all classes, we use summarized metrics. Below is a table with these summarized metrics and their definitions.
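For anyone wanting to reproduce these kinds of numbers on their own predictions, here is a small sketch that derives the four counts per class in a one-vs-rest fashion and then computes precision, recall, and f1 using their standard definitions. The function name and the toy labels are assumptions for illustration, and this sketch does not take a position on exactly how class accuracy is computed for the tables in this article.

```python
def per_class_metrics(y_true, y_pred, classes):
    """One-vs-rest TP/FP/TN/FN plus precision, recall, and F1 for each class."""
    results = {}
    for cls in classes:
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
        fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
        tn = len(y_true) - tp - fp - fn
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results[cls] = {"TP": tp, "FP": fp, "TN": tn, "FN": fn,
                        "precision": precision, "recall": recall, "f1": f1}
    return results


# Toy binary example: one class-0 image is mistaken for class 1.
y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
metrics = per_class_metrics(y_true, y_pred, classes=[0, 1])
print(metrics[0]["TP"], metrics[0]["FN"], metrics[1]["FP"], metrics[1]["TN"])  # 1 1 1 1
```

The single mistake shows up twice, as a FN for class 0 and a FP for class 1, which matches the bookkeeping described above.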
Below we see the first table of results, Accuracy (Acc). Notably, as the architectures get more complex, the models generally perform better. We see that the ResNet18 architecture generally has the worst performance. We also see that the two ResNet50s are very similar, but that ResNet50 with Cross Validation (ResNet50-CV) performs slightly better on PG13. The best performing model for accuracy is the YOLO model trained for single detection (Yolo-Det), which is the first model shared by Pitpe11. In fact, it performs better than the YOLO model trained for traditional classification (Yolo-CLS) on our test set. A side note: While Yolo-CLS was trained on the same ~21k images as the other models, Pitpe11 provided the weights for Yolo-Det. This means this model was trained differently than the others, potentially on data outside of our original dataset. This shows the importance of understanding the data and its distributions, and highlights that this problem can still greatly benefit from data exploration. ML is meant to be iterative in development.
We also consider the class f1 score to understand how well the models balance precision and recall in their predictions. In this metric we see the performance gap between architectures narrow. ResNet50 scores very similarly to ResNet18 despite having +5 points in most classes for accuracy. In both accuracy and f1, ResNet50 performs worse on PG13. Yolo-Det is again the best performing model across all metrics, with much higher performance in PG and PG13. ViT outperforms all ResNets across all classes and the Yolo-CLS for the NSFW classes. All in all, we can see that the more advanced architectures perform better on the datasets than the original baseline models.
From above we can identify that the best performing CV models are the ViT, the Yolo-Det, and the Yolo-CLS models. The final table we'll share here will consider these three models and their DCMs. We're specifically concerned about NSFW images being labeled as SFW. We also include XXX and X to R to help get a feel for the spread across the NSFW classes. The table below shows raw counts, since we can consider these as examples that may require moderation. Here we see another interesting pattern emerge from the Yolo-Det and ViT models. The ViT produces fewer total DCMs than the YOLOs but has slightly more DCMs when we neglect R. That is to say, the best performing model for fewest DCMs from [XXX, X, R] to [PG13, PG] is Yolo-Det. The difference here is that Yolo-Det has a hard time distinguishing the lines between X and R and a slightly harder time with R and PG13, whereas ViT makes slightly more mistakes in general. Yolo-CLS is inferior to the others in all cases, but does slightly better than ViT in X-PG, which to be fair is a critical case we need to consider.
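Counting DCMs from a list of predictions is straightforward. The sketch below tallies NSFW-to-SFW mistakes per (true, predicted) pair and also tracks the XXX-to-R and X-to-R pairs mentioned above. The function name, the constant names, and the toy labels are assumptions for illustration; the SFW/NSFW grouping follows the [XXX, X, R] versus [PG13, PG] split used in this section.

```python
from collections import Counter

SFW = {"PG", "PG13"}
NSFW = {"R", "X", "XXX"}
# Extra within-NSFW pairs tracked to see the spread across NSFW classes.
TRACKED_NSFW_PAIRS = {("XXX", "R"), ("X", "R")}


def down_class_misclassifications(y_true, y_pred):
    """Count down-class mistakes, keyed by (true rating, predicted rating)."""
    counts = Counter()
    for true, pred in zip(y_true, y_pred):
        nsfw_as_sfw = true in NSFW and pred in SFW
        if nsfw_as_sfw or (true, pred) in TRACKED_NSFW_PAIRS:
            counts[(true, pred)] += 1
    return counts


# Toy labels, made up for illustration.
y_true = ["X", "XXX", "R", "PG", "X"]
y_pred = ["PG13", "R", "R", "PG", "X"]
dcm = down_class_misclassifications(y_true, y_pred)
print(dcm[("X", "PG13")], dcm[("XXX", "R")], sum(dcm.values()))  # 1 1 2
```

Note that up-class mistakes, such as a PG image scored as R, are deliberately not counted here, since they are over-cautious rather than unsafe.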
Conclusion
In this article we explored the various CV architectures we used to try to address this classification problem. We experimented with ResNets of various sizes and techniques like cross-validation, we compared different YOLO architectures with different datasets, and utilized a ViT model. The main takeaways can be summarized as:
Our dataset is large and varied enough to take advantage of complex DL algorithms like ViT
The Yolo-Det model is superior to the other models, despite its simple architecture. However, it’s difficult to utilize outside of a single model or the mixture, as the Detection Model Class used in Ultralytics makes weight extraction a headache
The ViT model performs very well on DCM and f1, but definitely has room for improvement
No single CV model presented here can truly solve the problem. But a lot of data is being left on the table. Could incorporating the text modality help?
In the next article we'll look at the NLP models, mixture models, and their performance in media classification. We’ll see similar performance from the single modality text models (albeit a little better than the image models) and discover interesting findings from the mixture models.
While I organize my thoughts for the next article, please share your experience with CV modeling! Any thoughts or questions about this work, or experiences you want to share? I specifically didn’t delve into the mathematics or do a deep dive into the models above, as this article is already over 4,000 words! But if the community is interested I’d be happy to share code, write up mathematical deep dives, and provide more technical depth!
As this series continues, voice your opinions, thoughts, and ideas! Together, we can solve even the most open questions in ML and AI! Let's shape the future of safe and enjoyable content across Civitai and the Gen-AI community!
Special Thanks
To Pitpe11 for their collaboration on this work, specifically in their work on the Yolo-Det model