PyTorch .reshape()
Published Sep 23, 2024
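The torch.reshape() method returns a tensor with the same data and number of elements as a specified input tensor, but arranged in a given shape. When possible, the returned tensor is a view of the input, so no data is copied; otherwise, the data is copied. Which of the two happens is not guaranteed, so code should not depend on whether the result is a view or a copy.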
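The view-versus-copy behavior described above can be observed by comparing memory addresses with Tensor.data_ptr(). The sketch below is only an illustration of what happens in these two particular cases. It is not a guarantee, and production code should not rely on it. The variable names are arbitrary.

```python
import torch

# A contiguous 1D tensor holding 0..5
x = torch.arange(6)

# Reshaping a contiguous tensor can be done without copying, so y is a view of x
y = torch.reshape(x, (2, 3))
shares_memory = y.data_ptr() == x.data_ptr()

# Writing through the view also changes the original tensor
y[0, 0] = 100
print(shares_memory, x[0].item())  # True 100

# A transposed tensor is not contiguous, so flattening it requires a copy
t = torch.arange(6).reshape(2, 3).t()
flat = torch.reshape(t, (-1,))
print(t.is_contiguous(), flat.data_ptr() == t.data_ptr())  # False False
print(flat)  # tensor([0, 3, 1, 4, 2, 5])
```

In the second case, a view is impossible because the transposed elements are not laid out in memory in the order a 1D tensor would need.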
Syntax
torch.reshape(input, shape)
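- input: The PyTorch tensor to reshape.
- shape: A tuple or list of integers specifying the desired new shape. One dimension may be -1, in which case its size is inferred from the number of elements and the remaining dimensions.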
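The following sketch shows the shape parameter in a few forms: a tuple, a list, and a shape containing -1. It also shows the equivalent Tensor.reshape() method, which accepts the dimensions directly as separate arguments. The last call shows that a shape whose element count does not match the input raises a RuntimeError. The exact error message may vary between PyTorch versions.

```python
import torch

x = torch.arange(12)  # 12 elements: 0..11

# Passing the shape as a tuple or as a list gives the same result
a = torch.reshape(x, (3, 4))
b = torch.reshape(x, [3, 4])
print(torch.equal(a, b))  # True

# -1 lets PyTorch infer one dimension: 12 elements / 3 rows = 4 columns
c = torch.reshape(x, (3, -1))
print(c.shape)  # torch.Size([3, 4])

# The tensor method accepts the new dimensions as separate arguments
d = x.reshape(2, 6)
print(d.shape)  # torch.Size([2, 6])

# The new shape must hold exactly as many elements as the input
try:
    torch.reshape(x, (5, 2))
except RuntimeError as err:
    print("RuntimeError:", err)
```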
Example
The following code uses torch.reshape() to reshape a 2D tensor named size, which has shape (2, 2), into a 1D tensor, flattening all elements into a single row:
import torch

# Define the tensor
size = torch.tensor([[10, 11], [12, 13]])

# Reshape the tensor; -1 lets PyTorch infer the length of the single dimension
reshaped_size = torch.reshape(size, (-1,))  # Note the comma to make it a tuple

# Print the reshaped tensor
print(reshaped_size)
The example above prints the following output:
tensor([10, 11, 12, 13])
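A common variation flattens all dimensions except the first, for example to turn a batch of 2D samples into a batch of feature vectors. The sketch below assumes a hypothetical batch of two 3 x 4 samples; the names batch and features are illustrative only.

```python
import torch

# A batch of 2 samples, each a 3 x 4 grid (24 elements total)
batch = torch.arange(24).reshape(2, 3, 4)

# Keep the batch dimension and flatten the rest: 3 * 4 = 12 features per sample
features = torch.reshape(batch, (batch.shape[0], -1))
print(features.shape)  # torch.Size([2, 12])

# Elements are read in row-major (C) order, so the second sample starts at 12
print(features[1, 0].item())  # 12
```

Using batch.shape[0] instead of a hard-coded 2 keeps the same line working for any batch size.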