Adam Optimization
Adam (Adaptive Moment Estimation) is a widely used optimization algorithm in deep learning. It combines the momentum technique used with Stochastic Gradient Descent (SGD) and the per-parameter adaptive learning rates of RMSprop (Root Mean Square Propagation). Adam also builds on earlier adaptive methods such as AdaGrad, refining how the learning rate is adjusted for each parameter.
Because each parameter receives its own effective step size, training often converges faster, especially on large datasets. Adam also handles sparse gradients well and adapts to changing conditions during training, which makes it a strong default choice for deep neural networks.
How Adam Optimization Works
Adam computes an adaptive learning rate for each parameter, adjusting the step size taken during optimization. It maintains two moving averages per parameter: one of the gradients (the momentum term) and one of the squared gradients (as in RMSProp). These averages smooth out noisy updates and keep progress toward the minimum consistent.
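These two moving averages can be written out directly. Below is a minimal NumPy sketch of a single update to both averages; the decay rates are the defaults from the Adam paper, and the gradient values are made up purely for illustration:

import numpy as np

beta1, beta2 = 0.9, 0.999              # default decay rates from the Adam paper
m = np.zeros(3)                        # first moment: moving average of gradients
v = np.zeros(3)                        # second moment: moving average of squared gradients

g = np.array([0.1, -0.5, 2.0])         # illustrative gradient for a 3-parameter model

m = beta1 * m + (1 - beta1) * g        # momentum-style smoothing of the gradient
v = beta2 * v + (1 - beta2) * g**2     # RMSProp-style smoothing of the squared gradient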
One of Adam’s key features is bias correction. Because both moving averages are initialized at zero, they are biased toward zero during the first steps of training; bias correction rescales them to compensate, which stabilizes the early updates. Adam also works well with non-stationary data, such as datasets that change over time or have dynamic properties.
Finally, because Adam adjusts learning rates individually for each parameter, it excels in problems where gradients vary widely or parameters have differing levels of importance.
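Putting these pieces together, the full update can be sketched in a few lines. This is an illustrative implementation of the algorithm described above, not production code; the function name adam_step and all numeric values are placeholders:

import numpy as np

def adam_step(params, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    # One Adam update for a vector of parameters at step t (t starts at 1)
    m = beta1 * m + (1 - beta1) * g                  # update first moment
    v = beta2 * v + (1 - beta2) * g**2               # update second moment
    m_hat = m / (1 - beta1**t)                       # bias-correct the first moment
    v_hat = v / (1 - beta2**t)                       # bias-correct the second moment
    params = params - lr * m_hat / (np.sqrt(v_hat) + eps)  # per-parameter step
    return params, m, v

params = np.array([0.5, -0.3, 1.2])                  # illustrative parameters
m, v = np.zeros(3), np.zeros(3)
g = np.array([0.1, -0.5, 2.0])                       # illustrative gradient
params, m, v = adam_step(params, g, m, v, t=1)
print(params)

Note how each parameter's step is divided by its own np.sqrt(v_hat) term, so a parameter with consistently large gradients takes proportionally smaller steps.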
Key Advantages
- Adaptive Learning Rates: Adam adapts learning rates for each parameter, leading to more efficient training, especially in complex models.
- Bias Correction: Rescaling the zero-initialized moving averages in the early stages of training prevents overly small updates and improves stability.
- Robustness: Adam performs well with noisy gradients and varying data distributions, making it suitable for real-world applications.
Example
MNIST Classification with Adam Optimizer
This example shows how to train a neural network using the Adam optimizer with TensorFlow/Keras on the MNIST dataset, which consists of 70,000 grayscale images of handwritten digits (0-9) commonly used for benchmarking machine learning algorithms:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the images to values between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0

# One-hot encode labels
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

# Build and compile the model
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test Accuracy: {test_acc:.4f}')
The above code produces output similar to the following:
Test Accuracy: 0.9731
Note: The output will change with each run due to random weight initialization, data shuffling, and the stochastic nature of the Adam optimizer.
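Adam's decay rates and epsilon can also be set explicitly when the optimizer is constructed. A short sketch, reusing the model built in the example above; the values shown are the Keras defaults:

from tensorflow.keras.optimizers import Adam

optimizer = Adam(
    learning_rate=0.001,  # base step size
    beta_1=0.9,           # decay rate for the first moment (momentum)
    beta_2=0.999,         # decay rate for the second moment (squared gradients)
    epsilon=1e-07         # small constant that avoids division by zero
)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])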