Variational Autoencoder Tutorial: VAEs Explained
Have you ever wondered how machines create realistic new images or data instead of simply copying what they have seen? Traditional autoencoders compress and reconstruct data well, but they can't generate meaningful new examples because they don't learn the full data distribution. Variational Autoencoders (VAEs) address this with a probabilistic approach that learns continuous, meaningful latent representations.
This article will explore VAEs and how to implement them using PyTorch.
What is a Variational Autoencoder (VAE)?
Variational Autoencoders (VAEs) are powerful generative neural networks that extend traditional autoencoders by learning a probabilistic representation of data. Unlike regular autoencoders, which create fixed representations, VAEs produce probability distributions, each described by a mean (center point) and variance (spread). This allows the model to generate new data by sampling from the latent space, creating diverse and meaningful outputs.
The VAE architecture consists of three main parts:
Encoder: This neural network maps the input data to parameters of a probability distribution, specifically, the mean and log variance of a latent Gaussian distribution. This probabilistic output allows the model to capture uncertainty in the representation.
Latent space sampling: The model samples a latent vector using the mean and variance. This sampling is made differentiable through the reparameterization trick, which lets the network be trained end-to-end with gradient descent.
Decoder: The decoder takes the sampled latent vector and reconstructs it back into the original data space, producing an output similar to the input but also allowing for novel variations.
Unlike traditional autoencoders that compress and decompress data, VAEs learn a smooth, continuous latent space that can generate new, meaningful outputs.
Now that we understand the VAE architecture and how it enables generative modeling, let’s explore how VAEs differ from traditional autoencoders and what makes their approach unique.
How do VAEs differ from traditional autoencoders?
Traditional autoencoders and Variational Autoencoders both learn to compress and reconstruct data, but they differ fundamentally in their approach and capabilities. The table here highlights the key differences between the two:
| Feature | Traditional autoencoder | Variational autoencoder (VAE) |
|---|---|---|
| Latent representation | Deterministic fixed points | Probabilistic distribution (mean & variance) |
| Sampling | Not applicable; uses a fixed latent vector | Sampling enabled via the reparameterization trick |
| Objective/loss function | Minimizes reconstruction error only | Reconstruction loss + KL divergence regularizer |
| Generative modeling | Limited; not designed for generation | Designed to generate new, diverse data samples |
| Encoding nature | Point estimate of the input | Distribution capturing uncertainty |
| Behavior | Compresses data | Compresses and generates data |
These differences make VAEs a more powerful choice for tasks that involve generating new data or learning rich latent representations. But why exactly are VAEs favored in so many machine-learning applications?
What is the use of VAE?
VAEs are widely used in generative AI because they can learn structured latent spaces that enable smooth data generation and interpolation. They’re effective in tasks like image synthesis, data imputation, and anomaly detection across healthcare and computer vision domains.
Some of the core advantages and applications of VAEs are:
Generative capabilities: VAEs can generate new, realistic data samples by sampling from a learned latent distribution. Example: Generate new faces, handwritten digits, or audio samples.
Smooth latent space interpolation: The latent space learned by VAEs is continuous and meaningful, enabling smooth transitions between samples. Example: Morphing one face into another by interpolating in latent space.
Representation learning: VAEs learn compressed, informative representations of data, which can be useful for downstream tasks. Example: Feature extraction for classification.
Data imputation: VAEs can infer and fill in missing values because they learn to model the full data distribution. Example: Completing missing pixels in an image or gaps in time-series data.
Anomaly detection: VAEs can detect when an input significantly deviates from the training data by comparing reconstruction loss. Example: Identifying fraudulent transactions or faulty sensor readings.
VAEs offer versatile capabilities that make them a valuable tool in the deep learning toolkit.
Now that we understand VAE applications, let’s see how to build one from scratch using PyTorch.
Implementing VAE with PyTorch
Building a VAE from scratch is the best way to understand how it works. Let’s walk through the full implementation using the MNIST dataset, defining the model architecture, training loop, loss function, and inference steps.
Step 1: Setting up the VAE environment in PyTorch
To get started, we’ll install and import the necessary libraries and prepare the MNIST dataset.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load MNIST dataset
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
```
The libraries used in this code are as follows:
- `torch`, `torch.nn`, and `torch.optim`: Core PyTorch libraries used for building neural networks, defining layers, and optimizing model parameters.
- `torch.nn.functional` (`F`): Provides functional interfaces for activation functions, loss functions, and other operations that don't require parameter tracking.
- `torchvision.datasets` and `torchvision.transforms`: Used to download and preprocess the MNIST dataset.
- `torch.utils.data.DataLoader`: Handles batching, shuffling, and loading the dataset efficiently during training.
- `numpy`: Helpful for post-processing and latent space manipulation.
- `matplotlib`: For visualizing reconstructions and generations.
Step 2: Building the VAE model architecture in PyTorch
Building the encoder class
The encoder takes an input image and compresses it into a compact latent representation. But unlike a regular autoencoder, it doesn't output a single point; instead, it outputs two vectors: the mean (`μ`) and log-variance (`log σ²`). These define a probability distribution from which we'll later sample a latent vector.
```python
class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
```
Here:
- The encoder takes the flattened image (28×28 pixels = 784 inputs).
- It maps this to a hidden layer of size 400 using a ReLU activation.
- From the hidden layer, it produces two outputs:
  - `mu`: the center of the latent distribution
  - `logvar`: the log of the variance (used for sampling and numerical stability)
Implementing the decoder class
The decoder does the reverse: it takes a sampled point from the latent space and tries to reconstruct the original image. This helps the VAE learn to generate new data similar to the training inputs.
```python
class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        return torch.sigmoid(self.fc2(h))
```
This code:
- Takes a 20-dimensional latent vector `z` (sampled from the encoded distribution).
- Maps it back to the original data dimensions through a hidden layer.
- Applies a final `sigmoid` to ensure pixel values are between 0 and 1.
Creating the main VAE class
The VAE class combines the encoder and decoder and implements the reparameterization trick to keep training differentiable. Since sampling directly from a distribution (e.g., `z ~ N(μ, σ²)`) breaks backpropagation, we sample `ε ~ N(0, 1)` and compute `z = μ + σ * ε` instead. This trick allows gradients to flow through the sampling step during training.
```python
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)   # Standard deviation
        eps = torch.randn_like(std)     # Random noise
        return mu + eps * std           # Sample from the latent distribution

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decoder(z)
        return reconstructed, mu, logvar
```
In this code:
- The complete model takes an image, encodes it to a latent distribution, samples from it, and reconstructs the image.
- The `reparameterize()` function enables differentiable sampling: `logvar` is exponentiated and scaled to get the standard deviation, and `eps` is random noise from a standard normal distribution.
- The final sampled `z` has the desired mean and variance but is differentiable with respect to `mu` and `logvar`.
- The model outputs the reconstruction along with `mu` and `logvar` for computing the loss.
This trick allows VAEs to be trained end-to-end using standard gradient-based optimization techniques.
Step 3: Defining the loss function – Reconstruction + KL divergence
Training a VAE involves optimizing a composite loss function that balances two goals:
Reconstruction loss: Ensures the output is close to the original input.
KL divergence: Encourages the learned latent distribution to be close to a standard normal distribution.
The total loss is often referred to as the ELBO (Evidence Lower Bound), and we aim to maximize it (or equivalently, minimize its negative).
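For readers who want the underlying math, the quantity we minimize in the code below is the negative ELBO; in standard notation:

$$
\mathcal{L}(\theta, \phi; x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + D_{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big)
$$

where the first term corresponds to the reconstruction loss and the second to the KL divergence regularizer.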
```python
def loss_function(recon_x, x, mu, logvar):
    # Reconstruction loss (binary cross-entropy)
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss
```
Reconstruction loss: Binary cross-entropy works well for normalized pixel values (0–1); it measures how accurately the output matches the input.
KL divergence: Encourages `mu` and `logvar` to represent a distribution close to N(0, 1), helping the latent space generalize and avoid overfitting.
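For reference, the KL term has a closed form when the posterior is a diagonal Gaussian and the prior is a standard normal; summed over the latent dimensions (and over the batch in the code), it is exactly what the `kl_loss` line computes:

$$
D_{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$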
Step 4: Training and evaluating the VAE
Now that we’ve defined the VAE model and its loss function, it’s time to train it. We’ll run the training loop for several epochs, calculate the loss, and visualize how well the VAE learns to reconstruct and generate data over time.
```python
# Set training parameters
epochs = 10
learning_rate = 1e-3

# Initialize model and optimizer
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Track losses
train_losses = []

# Training loop
model.train()
for epoch in range(epochs):
    total_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.view(-1, 784).to(device)  # Flatten images
        optimizer.zero_grad()
        recon_x, mu, logvar = model(x)
        loss = loss_function(recon_x, x, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader.dataset)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

# Plot the training loss
plt.plot(train_losses)
plt.title("VAE Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()
```
This training loop does the following:
- Loads data in batches and sends it to the appropriate device (CPU/GPU).
- Performs forward and backward passes using the loss function we defined earlier.
- Updates model weights using the Adam optimizer.
- Tracks and prints the average loss per epoch to monitor convergence.
On executing this code, a possible output generated can be:
```
Epoch 1, Loss: 164.3773
Epoch 2, Loss: 121.5494
Epoch 3, Loss: 114.6132
Epoch 4, Loss: 111.6174
Epoch 5, Loss: 109.8854
Epoch 6, Loss: 108.7622
Epoch 7, Loss: 107.9547
Epoch 8, Loss: 107.2337
Epoch 9, Loss: 106.7546
Epoch 10, Loss: 106.2749
```
The epoch-loss values show how the model’s total loss decreases with training, indicating improved reconstruction and latent learning. The loss plot visually confirms this trend, showing a downward curve as the VAE converges.
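Beyond watching the loss curve, a quick qualitative check (a small sketch that goes beyond the original steps, reusing the `model` and `train_loader` defined above) is to compare a few training images with their reconstructions:

```python
# Compare originals (top row) with reconstructions (bottom row)
model.eval()
with torch.no_grad():
    x, _ = next(iter(train_loader))
    x = x.view(-1, 784).to(device)
    recon, _, _ = model(x)

originals = x[:8].cpu().view(-1, 28, 28)
reconstructions = recon[:8].cpu().view(-1, 28, 28)

fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i in range(8):
    axes[0, i].imshow(originals[i], cmap='gray')
    axes[1, i].imshow(reconstructions[i], cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].axis('off')
plt.suptitle("Top: originals, Bottom: reconstructions")
plt.show()
```

If training went well, the bottom row should look like a slightly smoothed version of the top row.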
Step 5: Testing the VAE by generating new samples
Once the VAE is trained, we can evaluate its generative capabilities by sampling from the learned latent space and observing the outputs.
Sampling from the latent space
We’ll sample random points from a standard normal distribution and feed them through the decoder to generate synthetic images:
```python
model.eval()

# Generate random latent vectors
with torch.no_grad():
    z = torch.randn(16, 20).to(device)  # 16 samples, 20-dimensional latent space
    generated = model.decoder(z).cpu()
    generated = generated.view(-1, 1, 28, 28)

# Plot generated samples
fig, axes = plt.subplots(2, 8, figsize=(12, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated[i][0], cmap='gray')
    ax.axis('off')
plt.suptitle("Generated Samples from Latent Space")
plt.show()
```
Here:
- We draw 16 random vectors from a standard normal distribution.
- The decoder turns each latent vector into an image.
- These images are reshaped and plotted to visualize what the VAE has learned.
Running this code displays a 2×8 grid of digit-like images decoded from the random latent vectors.
Latent space interpolation
We can also interpolate between two points in the latent space to see smooth transitions in generated images, which is a hallmark of a well-trained VAE.
```python
def interpolate(model, z_start, z_end, steps=10):
    # Linearly blend between the two latent vectors
    vectors = torch.stack([z_start * (1 - t) + z_end * t for t in torch.linspace(0, 1, steps)])
    with torch.no_grad():
        samples = model.decoder(vectors.to(device)).cpu()
    return samples.view(-1, 1, 28, 28)

# Sample two random points
z1 = torch.randn(1, 20)
z2 = torch.randn(1, 20)

# Interpolate and visualize
interpolated_images = interpolate(model, z1, z2)

# Plot results
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, ax in enumerate(axes.flat):
    ax.imshow(interpolated_images[i][0], cmap='gray')
    ax.axis('off')
plt.suptitle("Latent Space Interpolation")
plt.show()
```
This code:
- Creates a smooth blend between two latent vectors.
- Uses the decoder to visualize the intermediate representations as images.
- Helps us understand how changes in the latent space relate to changes in the output.
The result is a row of 10 images showing one generated digit gradually morphing into another as we move through the latent space.
After building and training a basic VAE, you might wonder whether we can improve on this design. What if we want more interpretable features, need to handle labels explicitly, or want to model sequences like text or time series? This is where the different types of VAEs come in.
Types of VAEs
VAEs have inspired a variety of extensions that adapt the original design to suit specific tasks or improve certain behaviors. Listed below are some of the most popular VAE variants.
β-VAE (Beta-VAE)
β-VAE introduces a hyperparameter `β` that controls the weight of the KL divergence in the loss function. This encourages the model to learn disentangled latent representations, where each dimension captures an independent factor of variation.
Use cases:
- Unsupervised disentanglement of features (e.g., shape vs. orientation in images).
- Fairness and interpretability in ML models.
- Controlled image or video editing.
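To make this concrete, here is a minimal sketch of how the `loss_function` from the implementation section above could be adapted into a β-VAE loss; the `beta` value shown is purely illustrative, not a recommended setting:

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(recon_x, x, mu, logvar, beta=4.0):
    # Same reconstruction term as the standard VAE loss
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL term, now scaled by beta to trade reconstruction quality for disentanglement
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss
```

Setting `beta > 1` presses the latent distribution closer to the prior, which typically improves disentanglement at some cost in reconstruction sharpness.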
Conditional VAE (CVAE)
CVAE enhances the standard VAE by conditioning both the encoder and decoder on external labels or attributes. This makes the model suitable for tasks where the generation should depend on specific conditions.
Use cases:
- Class-specific image generation (e.g., generate a digit “7”).
- Style transfer and attribute-controlled generation.
- Semi-supervised learning and data augmentation.
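As a rough, hypothetical sketch of the conditioning idea (shapes follow the MNIST setup used earlier, and these classes are illustrative, not a complete CVAE), the class label can be one-hot encoded and concatenated to the encoder input and to the latent vector before decoding:

```python
import torch
import torch.nn as nn

class CVAEEncoder(nn.Module):
    def __init__(self, input_dim=784, num_classes=10, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, y_onehot):
        # Condition the encoder by concatenating the label to the input
        h = torch.relu(self.fc1(torch.cat([x, y_onehot], dim=1)))
        return self.fc_mu(h), self.fc_logvar(h)

class CVAEDecoder(nn.Module):
    def __init__(self, latent_dim=20, num_classes=10, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, y_onehot):
        # Condition the decoder by concatenating the label to the latent code
        h = torch.relu(self.fc1(torch.cat([z, y_onehot], dim=1)))
        return torch.sigmoid(self.fc2(h))

# Example: ask the (untrained) decoder for a digit "7" to show the data flow
y = torch.nn.functional.one_hot(torch.tensor([7]), num_classes=10).float()
z = torch.randn(1, 20)
sample = CVAEDecoder()(z, y)
print(sample.shape)  # torch.Size([1, 784])
```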
Adversarial Autoencoder (AAE)
AAE uses a GAN-style discriminator to replace the KL divergence term with an adversarial loss. This helps the encoder more flexibly align its output with a prior distribution and often improves generation quality.
Use cases:
- Semi-supervised classification.
- Anomaly detection using adversarial scores.
- Enhanced visual fidelity in generative tasks.
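To illustrate only the adversarial ingredient, here is a minimal, hypothetical sketch of a latent-code discriminator and the two losses involved; the encoder outputs are replaced by random stand-in tensors:

```python
import torch
import torch.nn as nn

# Discriminator that tries to tell prior samples apart from encoder outputs
class LatentDiscriminator(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1), nn.Sigmoid(),
        )

    def forward(self, z):
        return self.net(z)

disc = LatentDiscriminator()
bce = nn.BCELoss()

z_prior = torch.randn(128, 20)    # samples drawn from the prior N(0, 1)
z_encoded = torch.randn(128, 20)  # stand-in for latent codes from an encoder

# Discriminator loss: prior codes labeled 1 ("real"), encoded codes labeled 0 ("fake")
d_loss = (bce(disc(z_prior), torch.ones(128, 1))
          + bce(disc(z_encoded), torch.zeros(128, 1)))

# Encoder (generator) loss: try to make encoded codes indistinguishable from the prior
g_loss = bce(disc(z_encoded), torch.ones(128, 1))
print(d_loss.item(), g_loss.item())
```

This adversarial objective replaces the KL divergence term of a standard VAE.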
Variational Recurrent Autoencoder (VRAE)
VRAE integrates recurrent neural networks (RNNs) into the VAE framework, making it capable of effectively handling sequential or time-series data.
Use cases:
- Music generation or synthesis.
- Time-series forecasting and compression.
- Modeling motion, handwriting, or speech patterns.
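As a hypothetical sketch of the core idea, a GRU can summarize a sequence into a hidden state that is then mapped to `mu` and `logvar`, mirroring the feed-forward encoder built earlier; the dimensions here are assumptions for illustration:

```python
import torch
import torch.nn as nn

class RecurrentEncoder(nn.Module):
    def __init__(self, input_dim=1, hidden_dim=64, latent_dim=16):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):          # x: (batch, seq_len, input_dim)
        _, h = self.gru(x)         # h: (1, batch, hidden_dim), final hidden state
        h = h.squeeze(0)
        return self.fc_mu(h), self.fc_logvar(h)

# Example: encode a batch of 8 univariate sequences of length 50
seqs = torch.randn(8, 50, 1)
mu, logvar = RecurrentEncoder()(seqs)
print(mu.shape, logvar.shape)      # torch.Size([8, 16]) twice
```

The reparameterization trick and a sequence decoder (e.g., another GRU) complete the model.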
Hierarchical VAE (HVAE)
HVAE stacks multiple layers of latent variables to create a hierarchy. This allows the model to capture features at different levels of abstraction, which is useful for complex data.
Use cases:
- Modeling large text corpora (sentence- and paragraph-level).
- High-resolution image synthesis.
- Multiscale representation learning for structured data.
Next, let's weigh the strengths and limitations of VAEs and see how they fare as generative models.
Advantages and limitations of VAEs
VAEs, like any tool, come with trade-offs. Understanding both sides can help you decide when and when not to use VAEs in your projects.
Advantages of VAEs
Smooth, structured latent spaces: The learned latent space allows for meaningful interpolation and arithmetic operations on data representations.
Stable training: VAEs are trained using standard backpropagation and stochastic gradient descent, making them more stable than GANs.
Generative flexibility: Capable of generating new, diverse data samples from learned distributions.
Probabilistic foundation: The Bayesian approach lends itself well to uncertainty estimation and principled modeling.
Limitations of VAEs
Blurry outputs: VAEs tend to produce less sharp images than GANs because of the pixel-wise loss functions they typically use.
KL divergence tuning: Balancing the reconstruction loss with the KL divergence term (especially in β-VAEs) can be delicate and dataset-dependent.
Posterior collapse: Sometimes, the encoder learns to ignore the latent variables entirely, collapsing the model into a deterministic autoencoder.
Conclusion
Variational Autoencoders (VAEs) combine neural networks with probabilistic modeling to generate new data by learning meaningful latent spaces. This tutorial covered the basics of VAEs, their differences from traditional autoencoders, and how to build and train one using PyTorch.
For a practical introduction to deep learning and PyTorch, check out Codecademy’s Build Deep Learning Models with PyTorch course to gain hands-on experience with building models.
Frequently asked questions
1. Is VAE better than GAN?
VAEs provide stable training and meaningful latent-space representations. GANs generate sharper, more realistic images but are harder to train. The choice depends on the application.
2. Is a VAE supervised or unsupervised?
VAEs are generally unsupervised models since they learn to represent data without labeled outputs.
3. What is `z` in a VAE?
`z` is the latent variable: a sampled vector from the learned distribution that represents compressed information about the input.
4. When should I use VAE vs autoencoder?
Use VAE when generating new data or for smooth latent space interpolation. Use regular autoencoders for compression tasks.
5. What is the main difference between a VAE and a GAN?
VAEs learn probabilistic latent representations and optimize a combined reconstruction and KL loss, while GANs use a generator-discriminator setup for adversarial training to produce realistic outputs.