Articles

How do Vision Transformers Work? Architecture Explained

The transformer architecture was a breakthrough in natural language processing (NLP). Vision transformers adapt the transformer architecture for computer vision tasks by converting an image into a sequence of patches.

In this article, we will discuss the vision transformer (ViT) architecture and its components to understand how they work. We will also discuss the advantages and limitations of vision transformers, and how they are different from convolutional neural network (CNN) models.

What is a vision transformer (ViT)?

Vision transformer (ViT) is a deep learning architecture that adapts the transformer model to computer vision tasks. While traditional CNN models use convolutional layers to capture spatial features for object detection, segmentation, and classification, ViTs take a fundamentally different approach by treating images as a sequence of patches. This architectural shift enables ViTs to use the powerful self-attention mechanism that made transformers successful in NLP, allowing them to model long-range dependencies and global relationships in images more effectively than CNNs.

The ViT architecture consists of four main components: patch embedding, positional encoding, transformer encoder, and classification head. Let’s discuss the vision transformer architecture to see how all these components work together.

Vision transformer architecture

The original vision transformer architecture, proposed in the paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”, contains components such as patch embedding, positional encoding, a transformer encoder, and a multilayer perceptron head, and looks as follows:

Image showing vision transformer (ViT) architecture

To understand the ViT architecture, let’s first examine each component in detail. Later, we will discuss how these components work together for the image classification task.

Understanding the different components of a vision transformer

Patches of input image

The ViT first splits an input image into non-overlapping patches. To create these patches, an image of shape NxNx3 is divided into sub-images of shape MxMx3, where the patch size M is a factor of the image size N. For example, an RGB image of shape 720x720x3 can be split into (720/16)² = 2025 patches of shape 16x16x3.
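As a quick sketch, this patch-splitting step comes down to a couple of array reshapes. The NumPy example below is an illustration rather than the paper’s exact implementation; the shapes follow the 720x720x3 example above:

```python
import numpy as np

N, M = 720, 16                    # image size and patch size (M must divide N)
image = np.random.rand(N, N, 3)   # a dummy RGB image standing in for real input

# (N, N, 3) -> (N/M, M, N/M, M, 3) -> (N/M, N/M, M, M, 3) -> (P, M, M, 3)
patches = image.reshape(N // M, M, N // M, M, 3).transpose(0, 2, 1, 3, 4)
patches = patches.reshape(-1, M, M, 3)

print(patches.shape)   # (2025, 16, 16, 3): (720/16)^2 = 2025 patches
```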

Patch embedding

Each image patch is flattened into a vector of length M² × 3. For example, a patch of shape 16x16x3 becomes a vector of length 16 × 16 × 3 = 768. These vectors are then processed by a linear neural network layer to create patch embeddings, which play the same role as the word embeddings of input tokens in the original transformer architecture.
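A minimal sketch of this flatten-and-project step is shown below (NumPy, with random weights standing in for the learned projection; the dimensions match the 16x16x3 example):

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, M, D = 2025, 16, 768   # patch count, patch size, embedding size
patches = rng.standard_normal((num_patches, M, M, 3))

# Flatten each 16x16x3 patch into a vector of length 16*16*3 = 768 ...
flat = patches.reshape(num_patches, -1)

# ... then project it with a linear layer (weights W, bias b) to get the
# patch embeddings; in a real ViT, W and b are learned during training.
W = rng.standard_normal((M * M * 3, D)) * 0.02
b = np.zeros(D)
embeddings = flat @ W + b

print(embeddings.shape)   # (2025, 768)
```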

Positional encoding

To understand the content of an input image, the ViT model needs to retain the spatial information of the image patches. Therefore, it adds identifiers to the embedding vectors of the image patches to highlight where each patch belongs in the image.

CLS token embedding

The ViT model needs a vector representing the image’s content for tasks like image classification or object detection. Hence, it prepends a special learnable embedding vector in addition to patch embedding vectors, which is represented using the CLS token.
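Prepending the CLS token and adding the positional embeddings from the previous section can be sketched together (NumPy; random vectors stand in for the learnable parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, D = 2025, 768
patch_emb = rng.standard_normal((num_patches, D))   # from the patch embedding step

# Prepend a single learnable CLS token embedding to the sequence ...
cls_token = rng.standard_normal((1, D)) * 0.02
x = np.concatenate([cls_token, patch_emb], axis=0)  # (2026, 768)

# ... then add a learnable positional embedding for every position,
# including position 0, which belongs to the CLS token.
pos_emb = rng.standard_normal((num_patches + 1, D)) * 0.02
x = x + pos_emb

print(x.shape)   # (2026, 768)
```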

Transformer encoder layer

The transformer encoder contains self-attention and neural network layers that capture relationships between embeddings. A ViT model can have any number of transformer blocks, each capturing different aspects of the input images, such as edges, textures, and global shape. If you aren’t aware of how the transformer encoder works, we suggest you read this transformer architecture article, which clearly explains the self-attention mechanism.

The final output of the transformer encoder layer is a sequence of contextual embeddings of the input image patches and the CLS token. Here, the CLS token serves as a compressed representation of the whole image, which is built through multiple self-attention layers. Its contextual embedding is used for image classification, while patch embeddings can be used for segmentation, object detection, and representation learning.
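The core of each encoder block is self-attention, where every embedding attends to every other embedding. The single-head NumPy sketch below shows the mechanism; LayerNorm, multiple heads, and the feed-forward sublayer of a real encoder block are omitted for brevity:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
S, D = 2026, 64            # sequence length (CLS + 2025 patches), head dimension
x = rng.standard_normal((S, D))

# Project the sequence into queries, keys, and values (random weights
# stand in for learned ones), then mix the values by attention weights.
Wq, Wk, Wv = (rng.standard_normal((D, D)) * 0.02 for _ in range(3))
Q, K, V = x @ Wq, x @ Wk, x @ Wv
attn = softmax(Q @ K.T / np.sqrt(D))   # (S, S): each token attends to all tokens
out = attn @ V                         # contextual embeddings, shape (S, D)

print(out.shape)   # (2026, 64)
```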

Multilayer perceptron (MLP) head

The multilayer perceptron (MLP) head, also called the classification head, uses the contextual embedding of the CLS token to classify the input image and gives the probabilities of different classes as its output.
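A sketch of the head (NumPy; a single linear layer plus softmax, with random weights and an assumed 10-class problem):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
D, num_classes = 768, 10
cls_embedding = rng.standard_normal(D)   # contextual CLS embedding from the encoder

# One linear layer maps the CLS embedding to class logits; softmax
# turns the logits into class probabilities.
W = rng.standard_normal((D, num_classes)) * 0.02
probs = softmax(cls_embedding @ W)

print(probs.shape)   # (10,) — class probabilities summing to 1
```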

Now that we have discussed how each component of a vision transformer works individually, let’s discuss how a vision transformer model works during training and inference.

How do vision transformers work?

ViTs work differently during training and inference. The ViT architecture uses the following steps during model training:

  1. The model first splits input images into non-overlapping patches of equal size.
  2. Next, it flattens each patch and creates a linear vector of the pixel values in the patches.
  3. The vectors are then passed to a neural network layer to create patch embeddings, representing each patch’s content.
  4. After this, the model adds positional embeddings to the patch embeddings to specify the spatial location of the vectors in the original image.
  5. Next, the model prepends a learnable CLS token embedding to the sequence of vectors and passes them to the transformer encoder.
  6. The transformer encoders produce contextual embeddings of the CLS token embedding and other patch embeddings.
  7. Finally, the MLP classification head uses the CLS token’s embedding to classify the input image and gives probability scores for different class labels.
  8. The predicted class label is then compared with the actual class label in the dataset using cross-entropy loss, and all the parameters in the patch embedding layer, transformer encoder, and MLP head are updated using the backpropagation mechanism.

Steps 1 to 8 are repeated over many batches of training images until the model becomes proficient at the image classification task and achieves acceptable accuracy.

During inference, the model follows steps 1 to 7 to predict the class label for a given image. Here, the model only runs the forward pass to generate the class label, and no backpropagation or gradient updates are required.
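Putting steps 1 to 7 together, a single inference forward pass can be sketched end to end. This NumPy miniature uses one attention block and random weights in place of trained parameters; real ViTs add LayerNorm, multi-head attention, and MLP sublayers in every block:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def vit_forward(image, M=16, D=64, num_classes=10):
    """Steps 1-7 of the ViT forward pass, in miniature."""
    N = image.shape[0]
    # 1-2. Split the image into patches and flatten each one.
    p = image.reshape(N // M, M, N // M, M, 3).transpose(0, 2, 1, 3, 4)
    flat = p.reshape(-1, M * M * 3)
    # 3. Project the flattened patches to patch embeddings.
    emb = flat @ (rng.standard_normal((M * M * 3, D)) * 0.02)
    # 4-5. Prepend the CLS token and add positional embeddings.
    x = np.concatenate([np.zeros((1, D)), emb], axis=0)
    x = x + rng.standard_normal((x.shape[0], D)) * 0.02
    # 6. One self-attention block produces contextual embeddings.
    Q, K, V = (x @ (rng.standard_normal((D, D)) * 0.02) for _ in range(3))
    x = x + softmax(Q @ K.T / np.sqrt(D)) @ V
    # 7. The MLP head classifies using the CLS token's embedding.
    return softmax(x[0] @ (rng.standard_normal((D, num_classes)) * 0.02))

probs = vit_forward(rng.standard_normal((224, 224, 3)))
print(probs.shape)   # (10,)
```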

Vision transformers advantages and limitations

Vision transformers have several advantages that stem from treating images as a sequence of embedding vectors and using transformer encoders to generate contextual embeddings. The following are some of the key benefits:

  • Global feature modeling: In vision transformers, every patch of an image interacts with every other patch during the self-attention mechanism. Because of this, ViTs efficiently capture long-range dependencies and global relationships in the input images.
  • Efficiency at scale: At very large model and dataset sizes, ViTs tend to achieve better accuracy for a given compute budget than comparably sized CNNs.
  • Transfer learning: ViTs pretrained on massive datasets transfer extremely well to other tasks like object detection and image segmentation. We can also fine-tune pre-trained ViTs on custom datasets to build powerful image processing models without investing much time and resources.
  • Unified architecture: Vision transformers use the same encoder that we use to process text data. This allows shared research advances. Any improvement to the transformer in natural language processing can be directly adopted in vision transformers. It also encourages us to build general-purpose foundation models to process text and image inputs.

Despite these advantages, ViTs also have certain limitations:

  • High Computational Cost: The self-attention layers in a vision transformer compare each embedding vector to every other embedding vector to compute relationships. As the image resolution and number of patches increase, the number of computations increases quadratically, making the model training and inference very expensive.
  • Data hungry: Vision transformers require large-scale datasets during model training. Because they lack the translational equivariance and invariance built into CNNs, training data often needs to be augmented with geometric transformations to help ViT models generalize to unseen data.
  • Longer training time: ViTs generally require more epochs and careful regularization to converge.
  • Less efficient for small datasets: Vision transformers can be less efficient for small datasets as they risk overfitting when pretrained weights aren’t available.
  • Deployment costs: Vision transformers are compute and memory-intensive, making it harder to deploy ViTs on edge devices.
  • Positional encoding dependency: ViTs rely on positional embeddings to detect spatial features in the image. Unlike CNNs, ViTs don’t scan the local features of the input image using kernels. This can make ViTs less robust to changes in scale or aspect ratio of images compared to CNNs, which naturally capture locality.
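The quadratic growth behind the first limitation above is easy to verify with quick arithmetic: doubling the image resolution quadruples the patch count, which multiplies the number of pairwise attention comparisons by sixteen.

```python
def num_patches(image_size, patch_size=16):
    """Number of non-overlapping patches in a square image."""
    return (image_size // patch_size) ** 2

# Self-attention compares every patch to every other patch: cost ~ patches^2.
for size in (224, 448, 896):
    n = num_patches(size)
    print(f"{size}x{size} image -> {n} patches -> {n * n} attention scores")
```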

To overcome these challenges, new vision transformer architectures have been introduced to reduce the computation costs and improve local feature detection. For example, the Swin transformer architecture uses a local window-based self-attention mechanism instead of global attention. It also uses hierarchical feature maps like CNNs and scales well to high-resolution images. Similarly, Pyramid Vision Transformer (PVT) models reduce sequence length gradually, like CNN pooling, to make ViTs more suitable for object detection and segmentation tasks.

Both the Swin and Pyramid Vision Transformer architectures borrow functionality from CNNs to improve local feature detection and reduce the computational requirements of vision transformers. Let’s discuss the similarities and differences between vision transformers and convolutional neural networks.

Vision transformers (ViTs) vs convolutional neural networks (CNNs)

Convolutional neural networks (CNNs) use convolutional and pooling layers, whereas ViTs use patch embedding and a self-attention mechanism to capture spatial features in an image. Due to this, there are significant differences in how CNNs work compared to ViTs. The following are some of the important differences between ViTs and CNNs:

| Aspect | Vision transformers (ViTs) | Convolutional neural networks (CNNs) |
| --- | --- | --- |
| Data requirement | Need large datasets; overfit on small datasets | Work well with large and small datasets alike |
| Computation cost | High: self-attention is quadratic in the number of patches | Lower: convolution cost grows linearly with input image size |
| Training speed | Slower to converge; needs heavy regularization | Converges faster with stable training |
| Scalability | Scales very well with data and model size | Performance saturates at large scale |
| Transfer learning | Excellent with pretrained ViTs | Good, but often task-specific |
| Deployment | Heavy compute and memory requirements | Lightweight variants available for edge deployment |
| Inductive bias | Weak inductive bias | Strong inductive bias from locality, translational invariance, and equivariance |
| Use cases | Large-scale classification, multimodal learning | Image classification, object detection, and segmentation |

Conclusion

Vision Transformers mark a significant advancement in computer vision by enabling global context learning and strong scalability. While CNNs remain efficient on smaller datasets, ViTs excel with large-scale data and flexible tasks. The future of vision models will likely combine both strengths, as in Swin and Pyramid Vision transformers. In this article, we discussed the basics of vision transformer models, including the ViT architecture, working, advantages, and limitations.

To learn more about vision transformers, you can take the Learn Image Classification with PyTorch course that discusses using PyTorch to build image classification models using CNNs and vision transformers. You might also like this Intro to Midjourney course that helps you peek into the generative image-based artificial intelligence using Midjourney.

Frequently asked questions

1. Is ViT better than CNN?

ViTs are better than CNNs for training computer-vision applications on large-scale datasets because they can capture global context and scale efficiently. However, CNNs outperform ViTs on small datasets.

2. Is DINO a vision transformer?

DINO isn’t a vision transformer. It’s a self-supervised learning framework for vision transformers that allows a ViT model to learn robust and invariant features from images without needing human-labeled data.

3. Does Yolo use Vision Transformers?

Traditional YOLO models do not use vision transformers. However, newer models like YOLOS and YOLO-Former use vision transformers for object detection tasks.

4. Is BERT a vision transformer?

No. BERT isn’t a vision transformer. It’s an encoder-only transformer model used for natural language processing tasks.

5. What is the difference between a transformer and a vision transformer?

A transformer is a general neural network architecture designed for sequential data, especially text data. A vision transformer is a specific use case of transformer architecture where transformer encoders are used to build computer vision applications by treating images as sequences of patches.

Codecademy Team

The Codecademy Team, composed of experienced educators and tech experts, is dedicated to making tech skills accessible to all. We empower learners worldwide with expert-reviewed content that develops and enhances the technical skills needed to advance and succeed in their careers.

Meet the full team

Learn more on Codecademy

  • Learn how to use Python to build image classification models using CNNs and vision transformers in this PyTorch tutorial.
    • With Certificate
    • Intermediate
      5 hours
  • Learn about what transformers are (the T of GPT) and how to work with them using Hugging Face libraries
    • Intermediate
      3 hours
  • Learn how to use PyTorch in Python to build text classification models using neural networks and fine-tuning transformer models.
    • With Certificate
    • Intermediate
      4 hours