Self-Supervised Learning

Published Feb 16, 2025

Self-Supervised Learning (SSL) is a machine learning approach that bridges the gap between supervised and unsupervised learning. It enables models to learn meaningful representations from unlabeled data by deriving supervision signals through pretext tasks.

SSL is widely applied in domains such as natural language processing (NLP), computer vision, and speech recognition.

Pipeline

A typical SSL pipeline consists of the following steps:

  1. Pretext Task Creation: Define a task where the model learns meaningful representations by predicting missing or transformed aspects of the input data.
  2. Feature Learning: Train a model using the pretext task to extract useful features from the data.
  3. Fine-Tuning: Transfer the learned features to a downstream task (e.g., classification, object detection).

For example, in computer vision, a common pretext task is predicting missing parts of an image:

# Pseudo code for an SSL pipeline
Initialize model
Load unlabeled dataset

For each epoch:
    - Apply transformation (e.g., masking, rotation, cropping)
    - Train model to predict missing or transformed part
    - Compute loss and update model parameters

Save learned representations
Fine-tune model on downstream task (e.g., classification)
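
To make the pretext-task step more concrete, the following PyTorch sketch corrupts random features of unlabeled vectors and trains a small network to restore them. This is a minimal illustration assuming a masking-style pretext task; the MaskedPredictor class, the ~25% mask ratio, and the tensor dimensions are illustrative choices, not part of a standard recipe:

import torch
import torch.nn as nn

# Hypothetical masking pretext task (illustrative names and sizes)
class MaskedPredictor(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)

unlabeled = torch.randn(100, 20)  # stand-in for an unlabeled dataset
pretext_model = MaskedPredictor(dim=20)
optimizer = torch.optim.Adam(pretext_model.parameters(), lr=0.01)
criterion = nn.MSELoss()

for epoch in range(10):
    mask = (torch.rand_like(unlabeled) > 0.25).float()  # zero out ~25% of features
    corrupted = unlabeled * mask
    optimizer.zero_grad()
    reconstructed = pretext_model(corrupted)
    # The supervision signal comes from the data itself:
    # compute the loss only on the masked positions
    loss = criterion(reconstructed * (1 - mask), unlabeled * (1 - mask))
    loss.backward()
    optimizer.step()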

Example

In the example below, a simple autoencoder learns feature representations by reconstructing unlabeled input data:

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple self-supervised learning model (Autoencoder)
class SimpleSSLModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SimpleSSLModel, self).__init__()
        # Encoder: Learns meaningful feature representations
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        # Decoder: Reconstructs input from learned representations
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)  # Generate feature representation
        decoded = self.decoder(encoded)  # Reconstruct input from encoded features
        return decoded

# Generate random unlabeled input data in [0, 1] to match the Sigmoid output range
data = torch.rand(100, 20)

# Define model, loss function, and optimizer
model = SimpleSSLModel(input_dim=20, hidden_dim=10)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop for self-supervised learning
for epoch in range(100):
    optimizer.zero_grad()
    reconstructed = model(data)
    loss = criterion(reconstructed, data)
    loss.backward()
    optimizer.step()

print("Self-supervised model training complete!")

This approach trains a model to reconstruct input data, allowing it to learn meaningful features without labeled supervision.
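
To complete the fine-tuning step of the pipeline, the learned encoder can then be transferred to a downstream task. The sketch below continues the example above by attaching a linear classification head to model.encoder and training on labeled data; num_classes and the random labels are hypothetical stand-ins for a real labeled dataset:

# Hypothetical downstream fine-tuning, reusing model, data, nn, and optim from above
num_classes = 5  # illustrative; depends on the downstream task
classifier = nn.Sequential(model.encoder, nn.Linear(10, num_classes))  # hidden_dim=10

labels = torch.randint(0, num_classes, (100,))  # stand-in labels for the 100 samples
clf_criterion = nn.CrossEntropyLoss()
clf_optimizer = optim.Adam(classifier.parameters(), lr=0.001)

for epoch in range(20):
    clf_optimizer.zero_grad()
    logits = classifier(data)  # encoder features feed the new classification head
    loss = clf_criterion(logits, labels)
    loss.backward()
    clf_optimizer.step()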
