Vanishing Gradient Problem
The vanishing gradient problem occurs when gradients shrink as they move backward through a deep neural network, so updates to the early layers become extremely small and training slows or stalls. It most often appears in networks that use saturating activation functions, such as sigmoid or hyperbolic tangent (tanh), or in networks with many layers.
How does it occur?
- Deep Architectures: Deeper networks have more layers that can multiply small gradient values.
- Sigmoid or Tanh Activations: These functions squash inputs into a narrow output range, so their derivatives are small and repeated multiplication shrinks the gradient (see the sketch after this list).
- Poor Weight Initialization: Wrong initial weight scales can cause gradients to vanish.
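As a quick, illustrative check (not part of the example later in this entry), the snippet below multiplies the largest possible sigmoid derivative, 0.25, once per layer to show how the backpropagated signal can shrink with depth:

```py
import torch

# The derivative of the sigmoid, s(z) * (1 - s(z)), never exceeds 0.25.
# Backpropagation multiplies one such factor per layer, so even in the
# best case the gradient signal shrinks geometrically with depth.
z = torch.zeros(1)  # sigmoid'(0) = 0.25 is the maximum possible value
sigmoid_grad = torch.sigmoid(z) * (1 - torch.sigmoid(z))

for depth in (5, 10, 20):
    bound = (sigmoid_grad ** depth).item()
    print(f"{depth} sigmoid layers: best-case gradient factor = {bound:.2e}")
```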
How to Fix It
- Use ReLU or Related Activations: ReLU functions help avoid squashing the gradient in early layers.
- Proper Initialization: Techniques like Xavier or He initialization maintain stable gradients.
- Batch Normalization: Normalizing layer inputs can stabilize gradient flow.
- Skip Connections: Shortcut paths let gradients bypass layers, reducing the effective depth of the network (a brief sketch of these remedies follows this list).
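The snippet below is a minimal sketch of two of these remedies, He initialization and a residual (skip-connected) block with batch normalization. The layer sizes and the `ResidualBlock` class are illustrative choices, not part of the example that follows:

```py
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """A hidden block whose output adds its input back (a skip connection),
    giving gradients a short path around the nonlinearity."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.BatchNorm1d(dim)  # batch normalization stabilizes layer inputs
        self.act = nn.ReLU()
        # He (Kaiming) initialization keeps activation variance roughly
        # constant across ReLU layers.
        nn.init.kaiming_normal_(self.linear.weight, nonlinearity="relu")
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        return x + self.act(self.norm(self.linear(x)))  # skip connection

net = nn.Sequential(
    nn.Linear(100, 128),
    ResidualBlock(128),
    ResidualBlock(128),
    nn.Linear(128, 10),
)
print(net(torch.randn(32, 100)).shape)  # torch.Size([32, 10])
```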
Example: Demonstrating and Addressing the Vanishing Gradient Problem
The following PyTorch example shows a simple deep network with sigmoid activations. The gradients in the earliest layers may become too small, slowing training. Switching to ReLU in the second network defined in the same snippet provides a potential fix:
```py
import torch
import torch.nn as nn
import torch.optim as optim

# Deep feedforward network with Sigmoid activations
class DeepSigmoidNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 128),
            nn.Sigmoid(),
            nn.Linear(128, 128),
            nn.Sigmoid(),
            nn.Linear(128, 128),
            nn.Sigmoid(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.layers(x)

# Create random data
x = torch.randn(32, 100)         # batch of 32
y = torch.randint(0, 10, (32,))  # target classes

model = DeepSigmoidNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Forward pass
outputs = model(x)
loss = criterion(outputs, y)

# Backward pass
loss.backward()

# Check the gradient norm of the first layer
grad_norm = model.layers[0].weight.grad.norm().item()
print(f"Gradient norm (Sigmoid net, first layer): {grad_norm:.6f}")

# Potential fix: using ReLU activations
class DeepReLUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.layers(x)

model_relu = DeepReLUNet()
optimizer = optim.SGD(model_relu.parameters(), lr=0.01)

outputs_relu = model_relu(x)
loss_relu = criterion(outputs_relu, y)
loss_relu.backward()

grad_norm_relu = model_relu.layers[0].weight.grad.norm().item()
print(f"Gradient norm (ReLU net, first layer): {grad_norm_relu:.6f}")
```
Running the code produces output similar to the following (exact values vary with the random initialization):
```
Gradient norm (Sigmoid net, first layer): 0.004324
Gradient norm (ReLU net, first layer): 0.118170
```
- DeepSigmoidNet: A fully connected network with multiple layers of sigmoid activation. The gradient often shrinks as it propagates back through each layer.
- Gradient Norm: The code checks the gradient norm of the first layer. A very small value suggests that those parameters receive negligible updates.
- DeepReLUNet: Switching to ReLU reduces the vanishing effect, which can be seen in the larger gradient norm for the first layer.
Using suitable activations, careful weight initialization, and techniques such as batch normalization and skip connections reduces the severity of the vanishing gradient problem, making training faster and more reliable.
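To see where gradients fade in a specific network, a short diagnostic loop like the one below can print the gradient norm of every Linear layer after a backward pass. This sketch assumes `model` (the `DeepSigmoidNet` from the example above) is still in scope and `loss.backward()` has already been called:

```py
# Diagnostic sketch: print per-layer gradient norms for the sigmoid model above.
for i, layer in enumerate(model.layers):
    if isinstance(layer, nn.Linear):
        print(f"Layer {i}: gradient norm = {layer.weight.grad.norm().item():.6f}")
```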