.reshape()

nelsonboamortesantiago
Published Sep 23, 2024

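The torch.reshape() function returns a tensor with the same data and number of elements as the input tensor, arranged in a given shape. When possible, the returned tensor is a view of the input, so no data is copied. Otherwise, it is a copy. Contiguous inputs, and inputs with compatible strides, can usually be reshaped without copying, but code should not depend on whether a view or a copy is returned.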
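The view-versus-copy behavior can be observed by comparing memory addresses with Tensor.data_ptr(). The sketch below illustrates typical behavior in current PyTorch versions. It is an illustration only, since whether a view or a copy is returned is an implementation detail that code should not rely on:

```python
import torch

# A freshly created tensor is contiguous, so reshape can return a view
x = torch.arange(6)
y = torch.reshape(x, (2, 3))
print(x.is_contiguous())             # True
print(y.data_ptr() == x.data_ptr())  # True: both share the same memory

# Transposing produces a non-contiguous tensor; flattening it cannot be
# expressed as a view, so reshape copies the data into a new buffer
t = torch.arange(6).reshape(2, 3).t()
flat = torch.reshape(t, (-1,))
print(t.is_contiguous())                # False
print(flat.data_ptr() == t.data_ptr())  # False: the data was copied
print(flat)                             # tensor([0, 3, 1, 4, 2, 5])
```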

Syntax

torch.reshape(input, shape)
  • input: A PyTorch tensor that you want to reshape.
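  • shape: A tuple (or list) of integers specifying the new shape. The product of its entries must equal the number of elements in input. At most one dimension may be -1, in which case its size is inferred from the remaining dimensions.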
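The following sketch exercises the shape parameter: passing it as a tuple or a list, letting -1 infer one dimension, and requesting a shape whose element count does not match the input, which makes PyTorch raise a RuntimeError:

```python
import torch

x = torch.arange(12)  # 1D tensor holding 0..11 (12 elements)

a = torch.reshape(x, (3, 4))   # shape given as a tuple
b = torch.reshape(x, [3, 4])   # shape given as a list works the same way
c = torch.reshape(x, (2, -1))  # -1 is inferred as 12 / 2 = 6

print(a.shape)            # torch.Size([3, 4])
print(torch.equal(a, b))  # True
print(c.shape)            # torch.Size([2, 6])

# 5 * 2 = 10 elements cannot hold the 12 elements of x
try:
    torch.reshape(x, (5, 2))
except RuntimeError as err:
    print("RuntimeError:", err)
```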

Example

The following code uses torch.reshape() to reshape a 2D tensor named size, with shape (2, 2), into a 1D tensor, flattening all of its elements into a single row:

import torch
# Define the tensor
size = torch.tensor([[10, 11], [12, 13]])
# Reshape the tensor
reshaped_size = torch.reshape(size, (-1,))  # (-1,) is a one-element tuple; -1 infers the length (4)
# Print the reshaped tensor
print(reshaped_size)

The example above prints the following output:

tensor([10, 11, 12, 13])
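Elements are read in row-major order, so reshaping the flattened tensor back to (2, 2) restores the original. The sketch below, which reuses the size tensor from the example, also inserts a length-1 dimension and shows that the Tensor.reshape() method gives the same result as the torch.reshape() function:

```python
import torch

size = torch.tensor([[10, 11], [12, 13]])
flat = torch.reshape(size, (-1,))

# Reshape the flattened tensor back to its original 2x2 shape
restored = torch.reshape(flat, (2, 2))
print(torch.equal(restored, size))  # True

# Insert a middle dimension of length 1: shape (2, 1, 2)
expanded = torch.reshape(size, (2, 1, 2))
print(expanded.shape)  # torch.Size([2, 1, 2])

# The Tensor.reshape() method is equivalent to torch.reshape()
print(torch.equal(size.reshape(-1), flat))  # True
```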

