PyTorch .index_select()
Published Dec 16, 2024
In PyTorch, the .index_select() function extracts specific elements from a tensor along a specified dimension based on indices and returns a new tensor.
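As a minimal sketch of the "returns a new tensor" behavior, the snippet below selects two rows and then overwrites one entry of the result. The tensor values and the variable names (source, index, picked) are made up for illustration. The source tensor should stay unchanged, since the selected tensor does not share storage with it.

```python
import torch

# Hypothetical sample data, chosen only for illustration
source = torch.tensor([[10, 20], [30, 40], [50, 60]])
index = torch.tensor([2, 0])

# Select rows 2 and 0 (in that order) using the method form
picked = source.index_select(0, index)

# Modify the result in place; this does not write back to source
picked[0, 0] = -1

print(picked)
print(source)
```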
Syntax
torch.index_select(input, dim, index, *, out=None)
- input: The input tensor to select elements from.
- dim: The dimension along which to index (e.g., 0 for rows, 1 for columns).
- index: A 1-D tensor containing the indices to select from the specified dimension.
- out (Optional): The output tensor to store the result. If specified, the result will be written to this tensor instead of creating a new one.
Note: If the out parameter is provided, the function stores the result in the given tensor. Otherwise, a new tensor is created and returned.
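The out parameter can be illustrated with a short sketch. It assumes a preallocated buffer whose dtype matches the input and whose shape matches the expected result. The names values, index, buffer, and result are hypothetical.

```python
import torch

# Hypothetical input, chosen only for illustration
values = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0],
                       [7.0, 8.0, 9.0]])
index = torch.tensor([0, 2])

# Preallocate a buffer: 2 selected rows by 3 columns, same dtype as the input
buffer = torch.empty(2, 3, dtype=values.dtype)

# Write the selected rows into the buffer instead of allocating a new tensor
result = torch.index_select(values, 0, index, out=buffer)

print(buffer)
# The returned tensor refers to the same memory as the buffer
print(result.data_ptr() == buffer.data_ptr())
```

Reusing a buffer like this may help avoid repeated allocations when the same selection shape is computed many times, for example inside a loop.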
Example
The following example demonstrates the use of .index_select():
```python
import torch

# Define a 2D tensor
ten = torch.tensor([[1, 2, 3, -8],
                    [4, 3, 8, 0],
                    [-1, 7, 6, 3],
                    [5, 6, 9, 0]])

# Define indices to select
indices = torch.tensor([0, 2])

# Select rows from the tensor (dimension 0)
r = torch.index_select(ten, 0, indices)
print("Rows selected:")
print(r)

# Select columns from the tensor (dimension 1)
c = torch.index_select(ten, 1, indices)
print("Columns selected:")
print(c)
```
This produces the following output:
```
Rows selected:
tensor([[ 1,  2,  3, -8],
        [-1,  7,  6,  3]])
Columns selected:
tensor([[ 1,  3],
        [ 4,  8],
        [-1,  6],
        [ 5,  9]])
```
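The index tensor does not have to be sorted or unique. The sketch below, using a smaller hypothetical tensor and hypothetical variable names (cols, same), repeats and reorders column indices, then compares the result with plain tensor indexing, which should select the same columns here.

```python
import torch

# Hypothetical 2 x 4 tensor, chosen only for illustration
ten = torch.tensor([[1, 2, 3, -8],
                    [4, 3, 8, 0]])

# Indices may repeat and may appear in any order
indices = torch.tensor([3, 0, 0])

# Select column 3, then column 0 twice
cols = torch.index_select(ten, 1, indices)
print(cols)

# Indexing with the same index tensor along dimension 1 gives the same values
same = ten[:, indices]
print(torch.equal(cols, same))
```

The function form can be convenient when the dimension is only known at runtime, since dim is passed as an ordinary argument rather than encoded in the position of a slice.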