PyTorch .squeeze()
Anonymous contributor
Published Nov 25, 2024
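The PyTorch `.squeeze()` function returns a tensor with the dimensions of size 1 removed from the input tensor. It is also available as a tensor method, so `x.squeeze()` is equivalent to `torch.squeeze(x)`. The returned tensor shares storage with the input, so modifying the contents of one also modifies the other.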
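The short sketch below illustrates the method form and the shared storage. The tensor `x` and its shape are chosen only for illustration.

```python
import torch

# A 1 x 3 tensor of zeros; dimension 0 has size 1
x = torch.zeros(1, 3)

# Method form, equivalent to torch.squeeze(x)
y = x.squeeze()
print(y.size())  # torch.Size([3])

# y is a view of x, so writing to y also changes x
y[0] = 5.0
print(x[0, 0].item())  # 5.0
```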
Syntax
```python
torch.squeeze(input, dim=None)
```
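- `input`: The input tensor from which dimensions of size 1 will be removed.
- `dim` (Optional): The dimension to squeeze. If provided, only the dimension at that index is removed, and only if its size is 1; otherwise the result has the same shape as `input`. If not provided, all dimensions of size 1 are removed. Since PyTorch 2.0, `dim` can also be a tuple of dimensions.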
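The following sketch shows how `dim` behaves in a few edge cases: pointing it at a dimension whose size is not 1, using a negative index, and passing a tuple. The tuple form assumes PyTorch 2.0 or later; the tensor shape is arbitrary.

```python
import torch

x = torch.rand(1, 2, 1, 3)

# Dimension 1 has size 2, so nothing is removed
a = torch.squeeze(x, dim=1)
print(a.size())  # torch.Size([1, 2, 1, 3])

# Negative indices count from the end: -2 is the size-1 dimension before the 3
b = torch.squeeze(x, dim=-2)
print(b.size())  # torch.Size([1, 2, 3])

# A tuple of dimensions (assumes PyTorch 2.0+)
c = torch.squeeze(x, dim=(0, 2))
print(c.size())  # torch.Size([2, 3])
```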
Example
The following example demonstrates how the `.squeeze()` function removes all dimensions of size 1 from the tensor `x` to produce `y`, and how the `dim` parameter limits the operation to a single dimension to produce `z`:
```python
import torch

# Create a tensor with dimensions (1, 1, 2, 1, 3)
x = torch.rand(1, 1, 2, 1, 3)
print("Original tensor size:", x.size())

# Apply the .squeeze() operation to remove all dimensions of size 1
y = torch.squeeze(x)
print("Squeezed tensor (all dims) size:", y.size())

# Apply the .squeeze() operation with dim=1 to remove only the dimension at index 1
z = torch.squeeze(x, dim=1)
print("Squeezed tensor (dim=1) size:", z.size())
```
The above code generates the following output:
```shell
Original tensor size: torch.Size([1, 1, 2, 1, 3])
Squeezed tensor (all dims) size: torch.Size([2, 3])
Squeezed tensor (dim=1) size: torch.Size([1, 2, 1, 3])
```
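Because calling `.squeeze()` without `dim` removes every dimension of size 1, it can also drop a batch dimension when the batch happens to contain a single item. The sketch below uses hypothetical tensors `batch_of_four` and `batch_of_one`, standing in for per-sample scores of shape `(batch, 1)`, to show why passing `dim` explicitly can be safer.

```python
import torch

# Hypothetical outputs with shape (batch, 1): one score per sample
batch_of_four = torch.rand(4, 1)
batch_of_one = torch.rand(1, 1)

# Without dim, every size-1 dimension is removed, including a batch of size 1
print(torch.squeeze(batch_of_four).size())  # torch.Size([4])
print(torch.squeeze(batch_of_one).size())   # torch.Size([])

# With dim=1, only the score dimension is removed, so the batch dimension survives
print(torch.squeeze(batch_of_four, dim=1).size())  # torch.Size([4])
print(torch.squeeze(batch_of_one, dim=1).size())   # torch.Size([1])
```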