PyTorch .reshape()

nelsonboamortesantiago's avatar
Published Sep 23, 2024

The torch.reshape() method returns a tensor with the same data and number of elements as the input tensor, but arranged in a specified shape. When possible, the returned tensor is a view of the input, so no data is copied. Otherwise, it is a copy. Because either outcome can occur, code should not depend on whether the result shares memory with the input.
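The following sketch shows both outcomes as they behave in current PyTorch releases. It compares Tensor.data_ptr() values, which is only a diagnostic and not something real code should rely on. A contiguous tensor is typically reshaped into a view that shares its memory. A transposed tensor is not contiguous, so it has to be copied before it can be flattened. If you need a guaranteed view, Tensor.view() raises an error rather than copying silently. If you need a guaranteed independent copy, call .clone() on the result.

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])

# x is contiguous, so reshape can return a view that shares x's memory
flat = torch.reshape(x, (4,))
print(flat.data_ptr() == x.data_ptr())  # True

# The transpose of x is not contiguous, so flattening it requires a copy
xt = x.t()
print(xt.is_contiguous())  # False
flat_t = torch.reshape(xt, (4,))
print(flat_t.data_ptr() == xt.data_ptr())  # False
print(flat_t)  # tensor([1, 3, 2, 4])
```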


Syntax

torch.reshape(input, shape)
  • input: A PyTorch tensor that you want to reshape.
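  • shape: A tuple or list of integers that specifies the new shape. The product of its dimensions must equal the number of elements in input. At most one dimension can be -1, and PyTorch infers its size from the remaining dimensions.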
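The sketch below demonstrates the shape parameter. It passes the shape as a tuple containing -1, passes it as a list, and also uses the equivalent Tensor.reshape() method, which accepts the dimensions directly as arguments. The last part shows that a shape with the wrong number of elements raises a RuntimeError. The tensor t is an illustrative input and does not come from the example further down the page.

```python
import torch

t = torch.arange(6)  # tensor([0, 1, 2, 3, 4, 5]), shape (6,)

# Shape given as a tuple; -1 lets PyTorch infer the second dimension (6 / 2 = 3)
a = torch.reshape(t, (2, -1))
print(a.shape)  # torch.Size([2, 3])

# Shape given as a list works the same way
b = torch.reshape(t, [3, 2])
print(b)

# The Tensor.reshape() method accepts the dimensions as separate arguments
c = t.reshape(3, 2)
print(torch.equal(b, c))  # True

# The new shape must contain exactly as many elements as the input (4 * 2 != 6)
try:
    torch.reshape(t, (4, 2))
except RuntimeError as err:
    print(f"RuntimeError: {err}")
```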

Example

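The following code uses torch.reshape() to turn the 2D tensor size, which has shape (2, 2), into a 1D tensor. The -1 in the target shape tells PyTorch to infer the length of that dimension, which here is 4, so all elements are flattened into a single dimension: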

import torch
# Define the tensor
size = torch.tensor([[10, 11], [12, 13]])
# Reshape the tensor
reshaped_size = torch.reshape(size, (-1,)) # Note the comma to make it a tuple
# Print the reshaped tensor
print(reshaped_size)

The example above produces the following output:

tensor([10, 11, 12, 13])
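A 1D tensor is not the same as a 2D tensor that has one row. This sketch reuses the size tensor from the example and compares the target shapes (-1,) and (1, -1). It then reshapes the flat result back to (2, 2) to show that reshaping changes only the shape, not the order of the elements.

```python
import torch

size = torch.tensor([[10, 11], [12, 13]])

# (-1,) produces a 1D tensor with 4 elements
flat = torch.reshape(size, (-1,))
print(flat.shape)  # torch.Size([4])

# (1, -1) keeps two dimensions: one row holding all 4 elements
row = torch.reshape(size, (1, -1))
print(row.shape)  # torch.Size([1, 4])
print(row)        # tensor([[10, 11, 12, 13]])

# Reshaping the flat tensor back restores the original layout
restored = torch.reshape(flat, (2, 2))
print(torch.equal(restored, size))  # True
```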

