Datasets and DataLoaders
PyTorch provides two essential data-handling abstractions: torch.utils.data.Dataset and torch.utils.data.DataLoader. A Dataset stores samples and their corresponding labels. A DataLoader wraps an iterable around a Dataset so that samples can be batched, shuffled, and loaded in parallel during model training.
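The following minimal sketch shows how the two abstractions divide the work. It uses TensorDataset, a ready-made Dataset subclass in torch.utils.data that wraps tensors. Indexing the dataset returns one sample, while iterating the DataLoader returns whole batches. The toy tensors are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (invented for illustration): 6 samples with 2 features each, plus integer labels
features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
labels = torch.tensor([0, 1, 0, 1, 0, 1])

# TensorDataset indexes every tensor it wraps along the first dimension
dataset = TensorDataset(features, labels)
print(len(dataset))  # 6
x0, y0 = dataset[0]  # a single (sample, label) pair

# DataLoader groups consecutive samples into batches of 3
loader = DataLoader(dataset, batch_size=3)
for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels.shape)  # torch.Size([3, 2]) torch.Size([3])
```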
Creating a Custom Dataset
To create a custom dataset in PyTorch, subclass torch.utils.data.Dataset and implement the following methods:
- __init__(self): Initializes the dataset, typically by loading data into memory or setting up file paths. It can take additional arguments, such as the data and labels in the example below.
- __len__(self): Returns the total number of samples in the dataset.
- __getitem__(self, idx): Retrieves the sample and its corresponding label at index idx.
Example
The following example defines a custom dataset class that wraps in-memory data and labels:
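```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        # Keep references to the samples and their labels
        self.data = data
        self.labels = labels

    def __len__(self):
        # Number of samples in the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Return a single (sample, label) pair
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label
```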
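The class above keeps all of its data in memory. Another common pattern is to store only file paths in __init__ and read each file in __getitem__, so large datasets never have to fit in memory at once. The sketch below assumes a hypothetical layout where each sample is a .txt file of comma-separated numbers. The class name TextFileDataset and the file format are illustrative assumptions, not part of PyTorch.

```python
import os
import tempfile

import torch
from torch.utils.data import Dataset


class TextFileDataset(Dataset):
    """Hypothetical dataset: one sample per .txt file of comma-separated numbers."""

    def __init__(self, root, labels):
        # Store only the file paths; the files themselves are read later, one at a time
        self.paths = sorted(
            os.path.join(root, name) for name in os.listdir(root) if name.endswith(".txt")
        )
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Read and parse a file only when its index is requested
        with open(self.paths[idx]) as f:
            values = [float(v) for v in f.read().strip().split(",")]
        return torch.tensor(values), self.labels[idx]


# Usage: write two small files to a temporary directory and read them back
root = tempfile.mkdtemp()
for i, text in enumerate(["1,2,3", "4,5,6"]):
    with open(os.path.join(root, f"sample_{i}.txt"), "w") as f:
        f.write(text)

dataset = TextFileDataset(root, labels=[0, 1])
print(len(dataset))  # 2
print(dataset[1])    # (tensor([4., 5., 6.]), 1)
```

Because the paths are sorted, index i always maps to the same file. This keeps each label aligned with its sample.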
DataLoader
The DataLoader class loads data from a Dataset in batches, which is essential for training neural networks efficiently. It can also shuffle the data (shuffle=True) and load batches in parallel using worker subprocesses (num_workers greater than 0).
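To make the batching behavior concrete, here is a small sketch over 10 toy samples. It shows how batch_size splits the data, how drop_last discards an incomplete final batch, and that shuffle=True only changes the order of samples. The data are invented for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 toy samples (values 0..9) so the batch boundaries are easy to see
dataset = TensorDataset(torch.arange(10))

# Default: the final batch holds whatever is left over (10 = 4 + 4 + 2)
sizes = [len(batch) for (batch,) in DataLoader(dataset, batch_size=4)]
print(sizes)  # [4, 4, 2]

# drop_last=True discards the incomplete final batch
sizes_drop = [len(batch) for (batch,) in DataLoader(dataset, batch_size=4, drop_last=True)]
print(sizes_drop)  # [4, 4]

# shuffle=True reorders samples; every sample still appears exactly once per epoch
torch.manual_seed(0)
seen = torch.cat([batch for (batch,) in DataLoader(dataset, batch_size=4, shuffle=True)])
print(sorted(seen.tolist()) == list(range(10)))  # True
```

For parallel loading, pass num_workers (for example, num_workers=2). On platforms that start worker processes with "spawn", such as Windows and macOS, code that iterates over such a loader generally needs to sit under an if __name__ == "__main__": guard.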
Example
Here is an example of using the DataLoader class:
```python
from torch.utils.data import DataLoader

# Assuming custom_dataset is an instance of CustomDataset
data_loader = DataLoader(dataset=custom_dataset, batch_size=4, shuffle=True)

for batch in data_loader:
    samples, labels = batch
    # Training code here
```
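One way to fill in the "Training code here" placeholder is sketched below. The synthetic regression data, the nn.Linear model, the MSE loss, and the SGD settings are all assumptions chosen for illustration. The unpacking into samples and labels works because the DataLoader's default collation stacks the per-sample (sample, label) pairs into one batched tensor for each field.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data (illustrative assumption): y = 2*x1 - 3*x2 + small noise
X = torch.randn(64, 2)
y = X @ torch.tensor([2.0, -3.0]) + 0.1 * torch.randn(64)
data_loader = DataLoader(TensorDataset(X, y), batch_size=4, shuffle=True)

model = nn.Linear(2, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

def full_loss():
    # Mean squared error over the whole dataset, without tracking gradients
    with torch.no_grad():
        return loss_fn(model(X).squeeze(1), y).item()

initial_loss = full_loss()
for epoch in range(5):
    for samples, labels in data_loader:
        predictions = model(samples).squeeze(1)  # shape: (batch_size,)
        loss = loss_fn(predictions, labels)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # compute gradients for this batch
        optimizer.step()       # update the model's weights
final_loss = full_loss()
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```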
Built-in Datasets
PyTorch also provides several built-in datasets that can be used with DataLoader without defining a custom dataset class. These are available in libraries such as torchvision, torchaudio, and torchtext.
Example
The following example loads the FashionMNIST training set from torchvision and wraps it in a DataLoader:
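```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Convert images to tensors, then normalize pixel values with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download (if needed) the FashionMNIST training split into the 'data' directory
train_dataset = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
```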
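Normalize((0.5,), (0.5,)) passes a single mean and a single standard deviation because FashionMNIST images are single-channel grayscale. ToTensor first converts pixel values to floats in [0, 1]. Normalize then applies the per-channel formula below, which with these values maps that range to [-1, 1]:

```latex
x' = \frac{x - \mu}{\sigma} = \frac{x - 0.5}{0.5} = 2x - 1,
\qquad x \in [0, 1] \;\Rightarrow\; x' \in [-1, 1]
```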
All contributors
- Anonymous contributor