Image Classification with PyTorch

PyTorch Image Models

PyTorch provides comprehensive tools for vision tasks through libraries like torchvision:

  • Classification: Assigning labels to entire images
  • Detection: Locating and identifying objects with bounding boxes
  • Segmentation: Pixel-level classification of image regions

from torchvision import datasets, transforms, models

# Load a pre-built model for classification
# (on torchvision < 0.13, use pretrained=True instead of the weights argument)
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Load a dataset
cifar10 = datasets.CIFAR10(root='./data',
                           train=True,
                           download=True,
                           transform=transforms.ToTensor())
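
torchvision also ships pre-built models for the other two tasks. A minimal sketch (these particular model choices are examples, not the only options):

from torchvision import models

# Detection: Faster R-CNN with a ResNet-50 FPN backbone
detector = models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# Segmentation: FCN with a ResNet-50 backbone
segmenter = models.segmentation.fcn_resnet50(weights="DEFAULT")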

PyTorch DataLoaders

DataLoaders in PyTorch are essential for managing image data. They handle batching, shuffling, and parallel loading during training, applying the dataset's transforms as samples are fetched. This keeps training efficient and randomizes the order of samples from epoch to epoch.

from torch.utils.data import DataLoader

# Create a DataLoader with batch size of 64
# Shuffle training data so each epoch sees batches in a different order
dataloader = DataLoader(dataset,
                        batch_size=64,
                        shuffle=True)

# Usage in training loop
for images, labels in dataloader:
    # Each iteration loads a batch of 64 images
    outputs = model(images)
    loss = criterion(outputs, labels)
    # ...continue with backpropagation

Image Transformation

Image transformations standardize data for model input:

  • Resize: Convert images to uniform dimensions
  • Normalize: Scale pixel values to specific range
  • ToTensor: Convert images to PyTorch tensors

Transformations are applied sequentially and should be identical for training and testing sets (except augmentations).

from torchvision import datasets, transforms

# Create transformation pipeline
transform = transforms.Compose([
    transforms.Resize((64, 64)),   # Resize to 64x64 pixels
    transforms.ToTensor(),         # Convert to tensor, scale to [0.0, 1.0]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize RGB channels to [-1.0, 1.0]
])

# Apply transformations when loading dataset
dataset = datasets.CIFAR10(root='./data',
                           train=True,
                           transform=transform,
                           download=True)

Image Augmentations

Image augmentations such as flipping, rotating, and color jittering create diverse variants of the training images, exposing the model to a wider range of appearances. Augmentations are applied only to training data, never to testing/validation data. This helps prevent overfitting and ensures the vision model generalizes well to new images.

  • Flipping: Mirror images horizontally/vertically
  • Rotation: Change image orientation
  • Color jittering: Adjust brightness, contrast, saturation

# Training transforms with augmentations
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),       # 50% chance of flipping horizontally
    transforms.RandomRotation(15),           # Rotate by a random angle up to ±15 degrees
    transforms.ColorJitter(brightness=0.2),  # Adjust brightness by up to ±20%
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
])

# Testing transforms without augmentations
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

Python CNN Basics

Convolutional Neural Networks (CNNs) excel at image tasks through specialized layers:

  • Convolutional layers: Extract spatial features using filters
  • Pooling layers: Reduce dimensionality and parameter count
  • Fully connected layers: Perform classification based on extracted features

Compared to standard neural networks, CNNs require fewer parameters and capture spatial relationships between pixels, making them the backbone of many vision applications like image classification.

import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Convolutional layer: 3 input channels, 12 filters, 3x3 kernel
        self.conv1 = nn.Conv2d(3, 12, kernel_size=3, padding=1)
        # Fully connected layers (assumes 32x32 inputs: 12 channels x 16x16 after pooling)
        self.fc1 = nn.Linear(12 * 16 * 16, 64)
        self.fc2 = nn.Linear(64, 10)  # 10 output classes

    def forward(self, x):
        # Apply convolution and ReLU activation
        x = F.relu(self.conv1(x))
        # Apply max pooling (2x2)
        x = F.max_pool2d(x, 2)
        # Flatten for fully connected layer
        x = x.view(x.size(0), -1)
        # Pass through fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
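
A quick shape check (a minimal sketch; the 32x32 input size is an assumption matching the fc1 dimensions above):

import torch

model = SimpleCNN()
sample = torch.randn(8, 3, 32, 32)  # batch of 8 RGB 32x32 images (e.g., CIFAR-10)
print(model(sample).shape)          # torch.Size([8, 10])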

PyTorch Conv2d Basics

A convolutional layer is the core building block of Convolutional Neural Networks (CNNs). In PyTorch, you create one with nn.Conv2d, configuring the number of input channels, the number of filters (output channels), kernel size, and padding to fit your network's needs.

import torch
import torch.nn as nn
conv_layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), padding=1)
# Example input with dimensions (batch_size=1, channels=3, height=32, width=32)
input_tensor = torch.randn(1, 3, 32, 32)
# Forward pass
output_tensor = conv_layer(input_tensor)
print(output_tensor.shape) # Expected: [1, 16, 32, 32]
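
With the default stride of 1, the output size follows (input + 2*padding - kernel_size) / stride + 1 = (32 + 2*1 - 3) / 1 + 1 = 32, which is why the spatial dimensions are unchanged.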

Vision Transformer ViTs

Vision Transformers (ViTs) adapt the attention mechanism from traditional transformers to analyze visual data, and are used for tasks like image classification, object detection, and image segmentation. They adapt the transformer architecture for images by (a patch-embedding sketch follows the list):

  • Splitting images into patches (like tokens in NLP)
  • Applying linear projection to create patch embeddings
  • Adding positional embeddings to preserve spatial information
  • Using a [class] token for global image representation
  • Processing patch embeddings through transformer encoder layers with self-attention
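
A minimal sketch of the patch-embedding step, using the standard trick of a convolution with kernel_size = stride = patch_size (the dimensions below mirror the ViT-Base configuration):

import torch
import torch.nn as nn

patch_size, embed_dim = 16, 768

# A conv with kernel = stride = patch size linearly projects each 16x16 patch
patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

image = torch.randn(1, 3, 224, 224)
patches = patch_embed(image)                 # [1, 768, 14, 14]
tokens = patches.flatten(2).transpose(1, 2)  # [1, 196, 768] - one embedding per patch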

Google ViT Base

Google’s pre-trained ViT models can be easily loaded through Hugging Face:

  • "google/vit-base-patch16-224" indicates:
    • “base” model size
    • 16x16 pixel patches
    • 224x224 input resolution
  • AutoImageProcessor handles image preprocessing
  • AutoModelForImageClassification loads model weights

from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch

# Load pre-trained ViT processor and model
model_name = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_name)
vit_model = AutoModelForImageClassification.from_pretrained(model_name)

# Process an image for the model
image = Image.open("cat.jpg")  # Load your image
inputs = processor(images=image, return_tensors="pt")

# Get predictions
with torch.no_grad():
    outputs = vit_model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
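
To turn the class index into a human-readable label, use the label table stored in the model config:

print(vit_model.config.id2label[predicted_class])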

Using DETR Models

Object detection involves both classification and localization:

  • Classification: Identifying what objects are in the image
  • Localization: Finding where objects are using bounding boxes

DETR (DEtection TRansformer) approaches this as a set prediction problem, using transformers to directly output a fixed-size set of predictions.

from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import torch

# Load DETR model and processor
model_name = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(model_name)
model = DetrForObjectDetection.from_pretrained(model_name)

# Process image and get predictions
image = Image.open("street_scene.jpg")  # Load your image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Get predicted boxes and classes
pred_boxes = outputs.pred_boxes[0]           # Normalized bounding box coordinates
pred_scores = outputs.logits[0].softmax(-1)  # Class probabilities (last index is the "no object" class)
pred_labels = pred_scores.argmax(-1)         # Predicted class labels

Facebook DETR Model

The pre-trained DETR model from Facebook AI can be accessed through Hugging Face's transformers library. Use DetrImageProcessor to preprocess images and DetrForObjectDetection to load the model. DETR combines a CNN backbone, a transformer encoder-decoder, and a feedforward prediction head, making it robust for object detection tasks.

from transformers import DetrImageProcessor, DetrForObjectDetection
# Load pre-trained model and processor
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
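
A minimal sketch of running detection end to end with post_process_object_detection, which converts DETR's normalized boxes to pixel coordinates and filters out low-confidence predictions (the image file and the 0.9 threshold are example choices):

from PIL import Image
import torch

image = Image.open("street_scene.jpg")  # example image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (height, width) of the original image, needed to rescale boxes
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs, threshold=0.9, target_sizes=target_sizes
)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())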

AI Object Detection

Object detection models are evaluated using:

  • IoU (Intersection over Union): Measures bounding box overlap
    • IoU = Area of Intersection / Area of Union (range: 0-1)
  • Precision: Accuracy of positive predictions
  • Recall: Ability to find all relevant objects
  • mAP (mean Average Precision): Average of precision across different IoU thresholds

A detection is considered correct when it has the right class AND IoU > threshold.

# Example of IoU calculation
def calculate_iou(box1, box2):
    """
    Calculate IoU between two bounding boxes.
    Each box format: [x1, y1, x2, y2] (top-left and bottom-right corners)
    """
    # Calculate intersection coordinates
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    # Calculate intersection area
    intersection = max(0, x2 - x1) * max(0, y2 - y1)

    # Calculate union area
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = box1_area + box2_area - intersection

    # Calculate IoU
    iou = intersection / union if union > 0 else 0
    return iou
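
A quick worked example: two 10x10 boxes offset by 5 pixels overlap in a 5x5 region, so IoU = 25 / (100 + 100 - 25) ≈ 0.143.

print(calculate_iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 0.1428...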

Python Transfer Learning

Transfer learning leverages pre-trained models to improve performance on new tasks:

  • Start with a model pre-trained on a large dataset
  • Freeze early layers to retain general feature extraction
  • Fine-tune later layers for specific task requirements
  • Adapt the final classification layer to the new number of classes

This approach requires less data and training time than training from scratch.

# Fine-tuning a pre-trained ViT for a new classification task
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch.nn as nn
import torch.optim as optim

# Load pre-trained model
model_name = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_name)
vit_model = AutoModelForImageClassification.from_pretrained(model_name)

# 1. Replace classification head for new task (10 classes)
vit_model.classifier = nn.Linear(vit_model.classifier.in_features, 10)

# 2. Freeze feature extraction layers
for param in vit_model.vit.parameters():
    param.requires_grad = False

# 3. Unfreeze specific layers to fine-tune
# Unfreeze last encoder layer
for param in vit_model.vit.encoder.layer[11].parameters():
    param.requires_grad = True

# 4. Set up optimizer with different learning rates
optimizer = optim.AdamW([
    {'params': vit_model.classifier.parameters(), 'lr': 0.0003},
    {'params': vit_model.vit.encoder.layer[11].parameters(), 'lr': 0.0001}
], weight_decay=0.001)

# 5. Train with fine-tuning
criterion = nn.CrossEntropyLoss()
# (Training loop would follow)
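
A minimal sketch of the training loop that would follow (train_loader is an assumed DataLoader yielding preprocessed pixel tensors and integer labels):

for epoch in range(3):  # e.g., 3 epochs
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = vit_model(pixel_values=images)  # Hugging Face models return an output object
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()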
