.reshape()
Published Sep 23, 2024
The torch.reshape() method returns a tensor with the same data and number of elements as the input tensor, but with the specified shape. When possible, the returned tensor is a view of the input and no data is copied. This is not guaranteed, so code should not depend on whether the result is a view or a copy.
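The short sketch below shows the view-versus-copy behaviour in practice. It assumes a current PyTorch release. It compares `data_ptr()` values as a quick check for shared storage, which is a heuristic used only for illustration. A contiguous input is reshaped without copying, while flattening a transposed (non-contiguous) tensor forces a copy.

```python
import torch

# A contiguous 1D tensor holding 0..5
x = torch.arange(6)

# Reshaping a contiguous tensor typically returns a view sharing x's storage
y = torch.reshape(x, (2, 3))
shares_storage = y.data_ptr() == x.data_ptr()
print(shares_storage)

# Because y is a view here, writing through it also changes x
y[0, 0] = 100
print(x[0].item())

# Transposing makes the tensor non-contiguous: [[0, 3], [1, 4], [2, 5]]
t = torch.arange(6).reshape(2, 3).t()
print(t.is_contiguous())

# Flattening t cannot be expressed as a view, so reshape copies the data
flat = torch.reshape(t, (-1,))
print(flat.data_ptr() == t.data_ptr())
print(flat)
```

Since the copy-or-view outcome depends on the input's memory layout, calling `.clone()` on the result is a safe choice when an independent tensor is required.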
Syntax
torch.reshape(input, shape)
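input: The PyTorch tensor to reshape.
shape: A tuple or list of integers specifying the desired new shape. A single dimension may be -1, in which case its size is inferred from the remaining dimensions and the number of elements in input.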
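The following sketch illustrates the shape parameter. It shows -1 inference, a list passed in place of a tuple, and the error raised when the requested shape does not hold the same number of elements as the input. The variable names are illustrative only.

```python
import torch

x = torch.arange(12)  # 12 elements: 0..11

# -1 asks PyTorch to infer this dimension: 12 / 3 = 4
a = torch.reshape(x, (3, -1))
print(a.shape)

# A list works as well as a tuple: 12 / (2 * 2) = 3
b = torch.reshape(x, [2, 2, -1])
print(b.shape)

# The new shape must hold exactly 12 elements; 12 is not divisible by 5
raised = False
try:
    torch.reshape(x, (5, -1))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```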
Example
The following code reshapes a 2D tensor size with shape (2, 2) into a 1D tensor using torch.reshape(), flattening all elements into a single row. Passing -1 lets PyTorch infer the length of that dimension, which is 4 here:
import torch

# Define the tensor
size = torch.tensor([[10, 11], [12, 13]])

# Reshape the tensor
reshaped_size = torch.reshape(size, (-1,))  # Note the comma to make it a tuple

# Print the reshaped tensor
print(reshaped_size)
The example above returns the following output:
tensor([10, 11, 12, 13])
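As a follow-up sketch, the flattened tensor can be reshaped back to its original (2, 2) shape. This version uses the Tensor.reshape() method form, which accepts the sizes as separate integers. It also compares the result against torch.flatten(), which produces the same 1D tensor for this input. The name restored is illustrative.

```python
import torch

size = torch.tensor([[10, 11], [12, 13]])
reshaped_size = torch.reshape(size, (-1,))

# Tensor.reshape is the method form; sizes can be passed as separate ints
restored = reshaped_size.reshape(2, 2)
print(torch.equal(restored, size))

# torch.flatten gives the same 1D result for this tensor
print(torch.equal(torch.flatten(size), reshaped_size))
```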