.unravel_index()
Published Feb 10, 2025
The `.unravel_index()` function in PyTorch converts flat (1D) indices into multi-dimensional coordinates for a given tensor shape. It is useful after operations that return indices into a flattened tensor, such as `torch.argmax()` called without a `dim` argument, because it recovers each position in the original tensor's dimensions.
Syntax
torch.unravel_index(indices, shape)
- `indices` (Tensor): An integer tensor of flat indices to convert. It is often 1D, but any shape works.
- `shape` (int, tuple of ints, or `torch.Size`): The dimensions of the target tensor (e.g., `(rows, columns)`).

Returns a tuple of tensors, one per dimension of `shape`. The i-th tensor holds the coordinates along dimension i, and each tensor has the same shape as `indices`.
Example
Basic Usage
The following example converts the flat indices `[3, 1, 5]` into 2D coordinates for a tensor of shape `(2, 3)`:
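```python
import torch

# Flat indices and target shape
indices = torch.tensor([3, 1, 5])
shape = (2, 3)

# Get multi-dimensional coordinates (one tensor per dimension)
coords = torch.unravel_index(indices, shape)

print("Coordinates (row, column):")
for row, col in zip(*coords):
    # .item() turns each 0-d tensor into a plain Python int for printing
    print(f"({row.item()}, {col.item()})")
```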
The above code will return the following output:
```
Coordinates (row, column):
(1, 0)
(0, 1)
(1, 2)
```
3D Tensor Example
The following example converts flat indices into coordinates for a 3D tensor of shape `(2, 2, 3)`:
```python
import torch

indices_3d = torch.tensor([7, 2])
shape_3d = (2, 2, 3)  # Dimensions: (depth, rows, columns)

coords_3d = torch.unravel_index(indices_3d, shape_3d)

print("Coordinates (depth, row, column):")
for d, r, c in zip(*coords_3d):
    # Convert each 0-d tensor to a Python int before formatting
    print(f"({d.item()}, {r.item()}, {c.item()})")
```
The above code returns the following output:
```
Coordinates (depth, row, column):
(1, 0, 1)
(0, 0, 2)
```
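The coordinates follow row-major (C-style) order, where the last dimension changes fastest, so each one comes from integer division (`//`) and modulo (`%`).

For the 2D case (`shape = (2, 3)`):

- Index `3` corresponds to row `1` (`3 // 3 = 1`) and column `0` (`3 % 3 = 0`).
- Index `1` corresponds to row `0` (`1 // 3 = 0`) and column `1` (`1 % 3 = 1`).
- Index `5` corresponds to row `1` (`5 // 3 = 1`) and column `2` (`5 % 3 = 2`).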
For the 3D case (`shape = (2, 2, 3)`):

- Index `7` is in depth `1` (`7 // (2 * 3) = 1`), row `0` (`(7 % 6) // 3 = 0`), and column `1` (`(7 % 6) % 3 = 1`).
- Index `2` is in depth `0` (`2 // (2 * 3) = 0`), row `0` (`(2 % 6) // 3 = 0`), and column `2` (`(2 % 6) % 3 = 2`).
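The same arithmetic generalizes to any number of dimensions by dividing by each dimension's row-major stride, meaning the product of all later dimensions (`6`, `3`, and `1` for shape `(2, 2, 3)`). The plain-Python sketch below needs no PyTorch; `row_major_strides`, `unravel`, and `ravel` are hypothetical helper names, and it assumes non-negative, in-range indices. `ravel` applies the inverse mapping to recover the original flat index.

```python
def row_major_strides(shape):
    """Hypothetical helper: stride of each dimension (product of all later dims)."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def unravel(index, shape):
    """Hypothetical helper: flat row-major index -> coordinate tuple."""
    total = 1
    for dim in shape:
        total *= dim
    if not 0 <= index < total:
        raise IndexError(f"index {index} is out of range for shape {shape}")
    # Dividing by the stride drops the faster-varying dimensions;
    # the remainder by the dimension size drops the slower-varying ones
    return tuple((index // s) % d for s, d in zip(row_major_strides(shape), shape))


def ravel(coords, shape):
    """Hypothetical helper: coordinate tuple -> flat row-major index."""
    return sum(c * s for c, s in zip(coords, row_major_strides(shape)))


shape_3d = (2, 2, 3)
print(row_major_strides(shape_3d))  # (6, 3, 1)
for flat in [7, 2]:
    coords = unravel(flat, shape_3d)
    print(flat, coords, ravel(coords, shape_3d))
# 7 (1, 0, 1) 7
# 2 (0, 0, 2) 2
```

Computing coordinates from strides mirrors the hand calculations above, and the round trip through `ravel` is a quick way to confirm the mapping is consistent.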