PyTorch .unravel_index()
Published Feb 10, 2025
The .unravel_index() function in PyTorch maps flat (1D) indices to multi-dimensional coordinates for a specified tensor shape. This is particularly useful when an operation returns linear indices and the corresponding positions along the original tensor's dimensions are needed.
Syntax
torch.unravel_index(indices, shape)
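- indices (Tensor): An integer tensor (commonly 1D) containing the flat indices to convert. Each value must lie in the range [0, prod(shape) - 1].
- shape (int, sequence of ints, or torch.Size): The dimensions of the target tensor (e.g., (rows, columns)).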
Returns a tuple of tensors, where each tensor represents the coordinate values along a specific dimension of the target shape.
Example
Basic Usage
Converting flat indices [3, 1, 5] into 2D coordinates for a tensor of shape (2, 3):
```python
import torch

# Flat indices and target shape
indices = torch.tensor([3, 1, 5])
shape = (2, 3)

# Get multi-dimensional coordinates
coords = torch.unravel_index(indices, shape)

print("Coordinates (row, column):")
for row, col in zip(*coords):
    print(f"({row}, {col})")
```
The above code will return the following output:
```
Coordinates (row, column):
(1, 0)
(0, 1)
(1, 2)
```
3D Tensor Example
Convert flat indices to coordinates in a 3D tensor of shape (2, 2, 3):
```python
import torch

indices_3d = torch.tensor([7, 2])
shape_3d = (2, 2, 3)  # Dimensions: (depth, rows, columns)

coords_3d = torch.unravel_index(indices_3d, shape_3d)

print("Coordinates (depth, row, column):")
for d, r, c in zip(*coords_3d):
    print(f"({d}, {r}, {c})")
```
The above code returns the following output:
```
Coordinates (depth, row, column):
(1, 0, 1)
(0, 0, 2)
```
For the 2D case (shape = (2, 3))
- Index 3 corresponds to row 1 (3 // 3 = 1), column 0 (3 % 3 = 0).
- Index 1 corresponds to row 0 (1 // 3 = 0), column 1 (1 % 3 = 1).
- Index 5 corresponds to row 1 (5 // 3 = 1), column 2 (5 % 3 = 2).
For the 3D case (shape = (2, 2, 3))
- Index 7 is in depth 1 (7 // (2 * 3) = 1), row 0 ((7 % 6) // 3 = 0), column 1 ((7 % 6) % 3 = 1).
- Index 2 is in depth 0 (2 // (2 * 3) = 0), row 0 ((2 % 6) // 3 = 0), column 2 ((2 % 6) % 3 = 2).
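The same calculation extends to any number of dimensions. The pure-Python sketch below uses two hypothetical helpers, unravel and ravel, which are not part of PyTorch. It peels off remainders starting from the last dimension, an equivalent way of writing the arithmetic above, and then rebuilds the flat index to confirm the round trip.

```python
from math import prod


def unravel(flat_index, shape):
    """Convert one flat index into row-major coordinates (hypothetical helper)."""
    if not 0 <= flat_index < prod(shape):
        raise ValueError("flat_index out of range for shape")
    coords = []
    # Walk the dimensions from last to first, peeling off one remainder each time
    for size in reversed(shape):
        flat_index, coord = divmod(flat_index, size)
        coords.append(coord)
    return tuple(reversed(coords))


def ravel(coords, shape):
    """Inverse of unravel: rebuild the flat index from coordinates."""
    flat_index = 0
    for coord, size in zip(coords, shape):
        flat_index = flat_index * size + coord
    return flat_index


print(unravel(7, (2, 2, 3)))        # (1, 0, 1)
print(unravel(2, (2, 2, 3)))        # (0, 0, 2)
print(ravel((1, 0, 1), (2, 2, 3)))  # 7
```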