a picture is worth a thousand words: breaking down Vision Transformers
I’m sure you’ve heard the famous saying ‘a picture is worth a thousand words’, which implies that complex ideas can be conveyed through a single still image. Despite its popularity, few people have heard this saying applied to Machine Learning. In the field of Computer Vision, however, it is very relevant. In recent years, a new type of neural network architecture has been gaining popularity in the Computer Vision community: Vision Transformers.
Let’s take a step back: what are Vision Transformers?
Vision Transformers, or ViTs for short, are a type of neural network architecture that uses self-attention mechanisms to process images. Like traditional Convolutional Neural Networks (CNNs), ViTs take images as input and output a prediction or a set of predictions. Unlike CNNs, however, ViTs do not use convolutional layers to process the images. Instead, they split each image into small patches and rely on self-attention to model the relationships between those patches. This is an extremely rough overview; throughout this article I’ll be providing an in-depth breakdown of Vision Transformers. With that being said, let’s jump into the thick of this article!
understanding ViTs
A couple of years back, a group of researchers from the Google Brain team wrote a paper called ‘An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale’. This paper built heavily on the ‘Attention Is All You Need’ paper, written in 2017 by researchers at Google Brain, Google Research and the University of Toronto.
To give a quick recap, transformers are a type of neural network architecture used in natural language processing (NLP) tasks, such as machine translation and text generation. They rely on a mechanism called “self-attention” to understand the relationships between different words in a sentence or sequence. A Transformer works by breaking an input sequence down into individual tokens (words or subwords) and representing them as vectors. These vectors are then processed in parallel through multiple layers of self-attention and feed-forward neural networks.
The self-attention mechanism allows the model to focus on different parts of the input sequence while generating each output word. It weights each word in the sequence by its relevance to the other words, which means the model can “attend” to different words and understand their contextual relationships, regardless of their position in the sequence. During self-attention, the model calculates attention scores between all pairs of words in the input sequence. These scores determine how much each word contributes to the representation of every other word: words with higher attention scores have a stronger influence on the final representation.
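To make the idea of pairwise attention scores more concrete, here’s a minimal NumPy sketch. It is deliberately simplified: the three word vectors are made up, and each vector is used directly as its own query, key and value instead of going through the learned projections a real transformer uses.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Three made-up 4-dimensional word vectors (one row per word)
words = np.array([
    [1.0, 0.0, 1.0, 0.0],   # "the"
    [0.0, 2.0, 0.0, 1.0],   # "cat"
    [1.0, 1.0, 0.0, 1.0],   # "sat"
])

d = words.shape[1]
# Attention scores between every pair of words: a 3 x 3 matrix
scores = words @ words.T / np.sqrt(d)
# Softmax turns each row into weights that sum to 1
weights = softmax(scores, axis=-1)
# Each word's new representation is a weighted mix of all word vectors
contextual = weights @ words

print(weights.round(2))
print(contextual.shape)  # (3, 4)
```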
This architecture is the model of choice for Natural Language Processing (NLP) tasks, as we’ve seen with ChatGPT, Google Bard, and others. In computer vision, however, the Convolutional Neural Network has remained the preferred option, partly because researchers found it difficult to scale CNN-like architectures that use specialized self-attention patterns efficiently on modern hardware accelerators. As a result, classic ResNet-like models were still considered state of the art. Recently, however, a new architecture called the Vision Transformer has been put in the spotlight. Heavily inspired by the success of the transformer architecture, it essentially applies a standard transformer directly to images: the image is split into patches, and the sequence of linear embeddings of these patches is provided as the input to the Transformer. The image patches are treated the same way tokens (or words) are treated in the regular transformer architecture. The ViT model was trained on image classification in a supervised fashion. Because it lacks some of the image-specific inductive biases built into CNNs, such as locality and translation equivariance, it does not generalize well when trained on too little data: on mid-sized datasets such as ImageNet it falls a few percentage points short of ResNets of comparable size. Pre-training on much larger datasets (14M to 300M images) changes the picture, and with enough data ViT matches or exceeds state-of-the-art CNNs. To make this clearer, let’s look at the attention mechanism at the heart of the model.
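Here’s a minimal NumPy sketch of multi-head self-attention over a sequence of patch embeddings. It is an illustration rather than a reference implementation: the function name `multi_head_self_attention` is hypothetical, the sizes (9 patches, embedding size 16, 4 heads) are arbitrary, and random matrices stand in for the learned weights.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (num_tokens, dim); w_q, w_k, w_v, w_o: (dim, dim)."""
    n, dim = x.shape
    head_dim = dim // num_heads

    # Split a (n, dim) projection into heads: (num_heads, n, head_dim)
    def split_heads(t):
        return t.reshape(n, num_heads, head_dim).transpose(1, 0, 2)

    # Project tokens to queries, keys and values, one slice per head
    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    # Scaled dot-product scores for every pair of tokens, per head: (num_heads, n, n)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    attn = softmax(scores, axis=-1)
    # Weighted sum of values per head, then merge the heads back: (n, dim)
    out = (attn @ v).transpose(1, 0, 2).reshape(n, dim)
    # Final linear projection mixes information across heads
    return out @ w_o, attn

rng = np.random.default_rng(0)
num_patches, dim, num_heads = 9, 16, 4
patches = rng.normal(size=(num_patches, dim))  # stand-in patch embeddings
w_q, w_k, w_v, w_o = (rng.normal(scale=dim ** -0.5, size=(dim, dim)) for _ in range(4))

out, attn = multi_head_self_attention(patches, w_q, w_k, w_v, w_o, num_heads)
print(out.shape)   # (9, 16)
print(attn.shape)  # (4, 9, 9): one 9 x 9 attention map per head
```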
This seems like a lot, but let me break it down into an analogy: Imagine you’re in a crowded room and you’re trying to focus on one person’s conversation. You can’t listen to everyone at once, so you “pay attention” to the person you’re interested in and tune out the rest. That’s essentially what this “Attention” mechanism does, but with parts of an image instead of people in a room.
The model looks at different parts of the image and decides which parts are important for making its final decision (like identifying if the image is of a cat or a dog). The parts it pays “attention” to get a higher importance score.
This process is done multiple times (once for each “head”) and all the results are combined at the end. This allows the model to focus on different features in the image at the same time (like the shape of the ears, color of the fur, etc.).
Finally, the model uses these importance scores to create a new representation of the image that highlights the important parts and downplays the less important ones. This new representation is then used for the task at hand, like image classification. Before jumping into the vision transformer architecture, I’m going to provide a quick rundown of how the transformer neural network operates to make it clear how it carries over to vision transformers.
transformer architecture
Before diving into the specifics of Vision Transformers, let us first understand the basics of attention and multi-head attention as presented in the original transformer paper. The Transformer is a model described in the ‘Attention Is All You Need’ (Vaswani et al., 2017) paper. It relies on self-attention instead of the recurrence or convolutions used by earlier sequence models such as LSTMs and CNNs, and it outperformed them on machine translation. The Multi-Head Attention blocks marked in the figure below are key to understanding how Transformers function. Each of these blocks (and each feed-forward block) is wrapped in a residual, or skip, connection similar to those found in ResNet.
Figure: the Transformer architecture
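The attention mechanism used by Transformers involves three elements: Query (Q), Key (K), and Value (V). In essence, it computes the association strength between a query token and every key token as the dot product of their Q and K vectors (scaled by the square root of the key dimension), turns these scores into weights with a softmax, and uses the weights to take a weighted sum of the V vectors.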
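For reference, this is the scaled dot-product attention formula from the ‘Attention Is All You Need’ paper, where d_k is the dimension of the key vectors:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```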
Single Head (Self-Attention)
Defining this attention computation over Q, K, and V as one head, we can now define the multi-head attention mechanism. Single-head attention applies the computation once to the full Q, K, and V. Multi-head attention instead gives each head i its own projection matrices W_i^Q, W_i^K, and W_i^V, which project the features into a separate, smaller subspace where that head computes its attention weights; the outputs of all heads are then concatenated and passed through a final projection W^O.
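In the same paper, multi-head attention runs h of these heads in parallel on projected inputs, concatenates their outputs, and applies a final output projection W^O:

```latex
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O},
\qquad
\mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V})
```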
Multi-Head Attention
Multi-head attention provides an advantage because it allows the model to attend to different parts of a sequence in different ways at the same time. This effectively means that:
- Each head can focus on different positions in the input, and combining the heads yields a richer overall representation.
- Each head can also capture a different kind of contextual relationship between words.
Now that we’ve gotten that out of the way, let’s jump into the specifics of the Vision Transformer architecture!
vision transformer architecture
At a very high level, a Vision Transformer carries out roughly seven steps, although this is a bit of an oversimplification. ViT treats an image as a sequence and predicts class labels from it, learning the image’s structure largely on its own rather than having it built in through convolutions. Each input image is regarded as a series of patches; each patch is turned into a single vector by flattening the channel values of all its pixels, and that vector is then linearly projected to the desired input dimension. The seven steps of the architecture can be summarized as follows:
- Convert image into patches
- Flatten patches
- Produce low-dimensional linear embeddings from patches
- Add positional embeddings
- Feed sequence to a standard transformer encoder
- Pretrain model with image labels
- Finetune on downstream dataset for image classification
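To put numbers on the first few steps, here’s a small Python helper that traces the shapes involved. The defaults (a 224 x 224 RGB image, 16 x 16 patches and a latent size of 768) match the ViT-Base/16 configuration from the paper; the helper’s name `vit_sequence_shapes` is hypothetical, and it also counts the learnable class token that ViT prepends to the patch sequence (covered further down).

```python
def vit_sequence_shapes(image_size=224, patch_size=16, channels=3, hidden_dim=768):
    # Step 1: the image is cut into a grid of non-overlapping square patches
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Step 2: each patch is flattened into P * P * C values
    flat_patch_dim = patch_size * patch_size * channels
    # Steps 3-4: each flattened patch is projected to hidden_dim; the sequence also
    # gets one extra learnable class token before position embeddings are added
    sequence_shape = (num_patches + 1, hidden_dim)
    return {
        "num_patches": num_patches,
        "flat_patch_dim": flat_patch_dim,
        "sequence_shape": sequence_shape,
    }

print(vit_sequence_shapes())
# {'num_patches': 196, 'flat_patch_dim': 768, 'sequence_shape': (197, 768)}
```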
Before jumping into the very specifics, it’s important to understand the blocks of the ViT architecture. Within the encoder of the Vision Transformer, there are multiple blocks and within each one, there are 3 very important components:
- Layer Norm
- Multihead Attention Network
- Multi-Layer Perceptrons
Layer Normalization rescales each token’s features to zero mean and unit variance, followed by a learnable scale and shift. This keeps the training process stable and helps the model cope with the variation present among the training images.
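Concretely, here’s a minimal NumPy sketch of layer normalization applied to a sequence of tokens; the epsilon of 1e-6 and the parameter names `gamma` and `beta` are illustrative choices.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # Normalize over the feature (last) dimension, separately for every token
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Learnable per-feature scale (gamma) and shift (beta)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
tokens = rng.normal(loc=5.0, scale=3.0, size=(197, 768))  # e.g. 197 tokens of size 768
normed = layer_norm(tokens, gamma=np.ones(768), beta=np.zeros(768))
print(normed.mean(axis=-1)[:3], normed.std(axis=-1)[:3])  # ~0 and ~1 per token
```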
The Multi-head Self-Attention network (MSA) is tasked with generating attention maps from the embedded visual tokens it receives. These attention maps help the network focus on the most important areas within the image, such as objects. The concept of attention maps is similar to ideas traditionally found in computer vision literature, such as saliency maps and alpha matting.
The MLP block is a two-layer network with a GELU (Gaussian Error Linear Unit) non-linearity between its two linear layers. A separate MLP on top of the encoder, commonly referred to as the MLP head, serves as the transformer’s output; applying softmax to this output yields the class labels in tasks such as Image Classification.
The research paper describes how, when using a vision transformer, the image is first split into patches, converted into linear embeddings, and then passed through the transformer encoder, whose output goes to the Multilayer Perceptron (MLP) head for classification. To understand the MLP better, let’s look at how it is typically implemented in PyTorch.
A typical implementation is a class that extends PyTorch’s nn.Module and initializes the network with parameters for input, hidden, and output features, as well as an optional dropout probability. The network consists of two linear layers (fc1 and fc2) separated by a Gaussian Error Linear Unit (GELU) activation function (act). Dropout is applied to both the hidden and output layers to mitigate overfitting during training. The forward method defines the sequence of operations during a forward pass: the first linear transformation, the GELU activation, and dropout, followed by the second linear transformation and dropout.
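If you want to experiment without PyTorch, here’s a NumPy sketch that mirrors that structure: fc1, GELU, dropout, fc2, dropout. It is an illustration under a few assumptions: the class name `MlpSketch` is hypothetical, the weights are randomly initialized, GELU uses the common tanh approximation rather than the exact erf form, and dropout is only applied when `training=True`.

```python
import numpy as np

def gelu(x):
    # tanh approximation of the Gaussian Error Linear Unit
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

class MlpSketch:
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0, seed=0):
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        rng = np.random.default_rng(seed)
        # fc1 and fc2: weight matrices plus biases (random init for illustration)
        self.w1 = rng.normal(scale=in_features ** -0.5, size=(in_features, hidden_features))
        self.b1 = np.zeros(hidden_features)
        self.w2 = rng.normal(scale=hidden_features ** -0.5, size=(hidden_features, out_features))
        self.b2 = np.zeros(out_features)
        self.drop = drop
        self.rng = rng

    def _dropout(self, x, training):
        # Inverted dropout: zero units with probability `drop`, rescale the survivors
        if not training or self.drop == 0.0:
            return x
        keep = self.rng.random(x.shape) >= self.drop
        return x * keep / (1.0 - self.drop)

    def forward(self, x, training=False):
        x = gelu(x @ self.w1 + self.b1)      # fc1 -> act
        x = self._dropout(x, training)       # drop
        x = x @ self.w2 + self.b2            # fc2
        return self._dropout(x, training)    # drop

mlp = MlpSketch(in_features=768, hidden_features=3072, drop=0.1)
tokens = np.random.default_rng(1).normal(size=(197, 768))
print(mlp.forward(tokens).shape)             # (197, 768)
```

The 768, 3072, 768 sizes match the MLP block of ViT-Base; set `drop` to 0.0 to disable dropout entirely.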
The overall goal of the Multi-Layer Perceptron (MLP) is to model complex relationships within the input data and make predictions or classifications. An MLP is a type of artificial neural network that consists of multiple layers of nodes (neurons) connected by weights. In the implementation described above, the MLP has an input layer with a specified number of features, a hidden layer with a customizable number of nodes and GELU activation, and an output layer with the desired number of output features.
One thing to note is that the transformer architecture is extremely compute-heavy: self-attention compares every token with every other token, so its cost grows quadratically with the sequence length. The model was initially made for natural language; now imagine taking a picture, converting it into the right input and passing it in. Take an image of size n x n: after flattening it, we have n² pixels, so an attention matrix recording which pixels attend to one another would be n² x n². As n grows, this operation becomes more and more compute-heavy. The key idea of this paper is to break the image down into square patches, similar to windows. With patches of size P x P, the sequence has only n²/P² tokens, so the attention matrix shrinks to (n²/P²) x (n²/P²).
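A quick back-of-the-envelope calculation shows how much patches help. The hypothetical helper `attention_matrix_entries` below counts the entries of one attention matrix when every pixel is a token versus when every 16 x 16 patch is a token, for a 224 x 224 image (ignoring channels, heads and batch size):

```python
def attention_matrix_entries(image_side, patch_side=1):
    # Number of tokens = number of (patch_side x patch_side) patches in the image
    tokens = (image_side // patch_side) ** 2
    # Self-attention compares every token with every token: tokens x tokens scores
    return tokens * tokens

pixel_level = attention_matrix_entries(224)                 # every pixel is a token
patch_level = attention_matrix_entries(224, patch_side=16)  # every 16 x 16 patch is a token
print(f"{pixel_level:,} vs {patch_level:,}")
# 2,517,630,976 vs 38,416
print(pixel_level // patch_level)  # 65536, i.e. 16**4 times fewer entries
```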
These patches are flattened and passed through a single linear (feed-forward) layer. Because the transformer uses a constant latent vector size D throughout all of its layers, this trainable linear projection maps each flattened patch to D dimensions. The weight matrix of this layer acts as the patch embedding matrix, and the outputs of the projection are called patch embeddings.
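Here’s a minimal NumPy sketch of these two steps: cut an (H, W, C) image into non-overlapping P x P patches, flatten each one, and multiply by a projection matrix `E` to get D-dimensional patch embeddings. The random image and the randomly initialized `E` (in a trained model it is learned) are assumptions for illustration, and the function name `patchify` is hypothetical.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into flattened patches of shape (N, P*P*C)."""
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image sides must be divisible by the patch size"
    # (H, W, C) -> (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C)
    patches = image.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    # Flatten the grid into a sequence (row by row) and each patch into one vector
    return patches.reshape(-1, p * p * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))
flat_patches = patchify(image, patch_size=16)  # (196, 768)

hidden_dim = 768                               # latent size D
E = rng.normal(scale=0.02, size=(flat_patches.shape[1], hidden_dim))  # patch embedding matrix
patch_embeddings = flat_patches @ E            # (196, D)
print(flat_patches.shape, patch_embeddings.shape)
```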
The vision transformer prepends a learnable embedding, the class token, to the sequence of embedded patches. The state of this embedding at the output of the Transformer encoder acts as the image representation. During both pre-training and fine-tuning, a classification head is attached to this output; it is implemented by an MLP with one hidden layer at pre-training time and by a single linear layer at fine-tuning time. This design is heavily influenced by the original BERT paper, specifically its [class] token, which here is concatenated with the patch projections.
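Prepending the class token is just a concatenation along the sequence axis. In this NumPy sketch, the zero-initialized `cls_token` and the random patch embeddings are placeholders; in a real model the token’s values are learned during training.

```python
import numpy as np

num_patches, hidden_dim = 196, 768
rng = np.random.default_rng(0)
patch_embeddings = rng.normal(size=(num_patches, hidden_dim))  # stand-in patch embeddings

# One learnable class token, shared across all images (zero-initialized here)
cls_token = np.zeros((1, hidden_dim))
# Prepend it: position 0 holds the class token, positions 1..196 hold the patches
tokens = np.concatenate([cls_token, patch_embeddings], axis=0)
print(tokens.shape)  # (197, 768)

# After the encoder runs, the classification head reads only the output at position 0.
```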
Another key component of the vision transformer architecture is the position embeddings. They are added to the patch embeddings to retain position information. This is really important because self-attention by itself does not take sequence order into account: without position embeddings, the transformer would have no notion of where each patch sits in the image. Standard learnable 1D position embeddings are used, and the resulting sequence of embedding vectors serves as the input to the encoder.
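Adding learnable 1D position embeddings amounts to one extra (sequence length x D) parameter matrix that is summed element-wise with the token sequence. In this NumPy sketch, the random initialization (standard deviation 0.02) and the stand-in tokens are assumptions for illustration; the last lines check that moving tokens to new positions changes the encoder input itself, not just its row order.

```python
import numpy as np

seq_len, hidden_dim = 197, 768           # 196 patches + 1 class token
rng = np.random.default_rng(0)
tokens = rng.normal(size=(seq_len, hidden_dim))

# One learnable vector per position (random init here; learned during training)
pos_embedding = rng.normal(scale=0.02, size=(seq_len, hidden_dim))
encoder_input = tokens + pos_embedding

# Move every token one position to the right (a fixed, non-identity reordering).
perm = np.roll(np.arange(seq_len), 1)
# With position embeddings, moved tokens pick up different position vectors, so the
# result is not simply a row-shuffled copy of the original encoder input.
print(np.allclose(tokens[perm] + pos_embedding, encoder_input[perm]))  # False
```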
The Transformer encoder consists of alternating layers of multi-head self-attention and MLP blocks, with Layer Norm applied before every block and a residual connection after every block. The encoder’s output for the class token is sent into the MLP head for image classification. This MLP reduces the dimensionality of the vector to the number of classes in the classification task, and a softmax function turns its output into a probability distribution over the classes, providing the final classification result. Overall, putting all these pieces together is exactly how a vision transformer works! If you want to learn more, here is the link to the original paper.
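Putting the pieces together, here’s a compact NumPy sketch of one pre-norm encoder block (Layer Norm before each sub-block, a residual connection after it, as in the ViT paper) followed by a classification head on the class token. Everything else is a simplifying assumption: tiny sizes (dimension 32, 4 heads, 10 classes), a single block instead of a stack, random untrained weights, no biases or dropout, a single linear layer as the head (as at fine-tuning time), and hypothetical function names.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, heads, hidden, num_classes, seq_len = 32, 4, 64, 10, 17  # 16 patches + class token

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-6):
    # Learnable scale/shift omitted (equivalent to gamma=1, beta=0)
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def init(*shape):
    # Random weights scaled by fan-in (stand-ins for trained parameters)
    return rng.normal(scale=shape[0] ** -0.5, size=shape)

params = {"w_qkv": init(dim, 3 * dim), "w_o": init(dim, dim),
          "w1": init(dim, hidden), "w2": init(hidden, dim),
          "w_head": init(dim, num_classes)}

def msa(x, p):
    n, hd = x.shape[0], dim // heads
    q, k, v = np.split(x @ p["w_qkv"], 3, axis=-1)            # each (n, dim)
    q, k, v = (t.reshape(n, heads, hd).transpose(1, 0, 2) for t in (q, k, v))
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(hd))     # (heads, n, n)
    return (attn @ v).transpose(1, 0, 2).reshape(n, dim) @ p["w_o"]

def encoder_block(x, p):
    x = x + msa(layer_norm(x), p)                              # MSA sub-block + residual
    x = x + gelu(layer_norm(x) @ p["w1"]) @ p["w2"]            # MLP sub-block + residual
    return x

def classify(tokens, p):
    z = encoder_block(tokens, p)
    cls_out = layer_norm(z[0])                                 # class-token output only
    return softmax(cls_out @ p["w_head"])                      # class probabilities

tokens = rng.normal(size=(seq_len, dim))  # stand-in for class token + patch + position embeddings
probs = classify(tokens, params)
print(probs.shape, probs.sum())           # (10,) and ~1.0
```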
If you’ve made it this far, thank you for taking time out of your day to read this article and I hope it was insightful and valuable!
If you have any questions regarding this article or just want to connect, you can find me on LinkedIn or my personal website :)