PyTorch .select()
Published Dec 16, 2024
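The .select() method in PyTorch returns the slice of a tensor at a given index along a specified dimension. The result is a view of the input tensor with that dimension removed, so it has one fewer dimension than the input. It is available both as the function torch.select() and as the tensor method Tensor.select().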
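The sketch below shows the dimension reduction on a 3D tensor and compares each result with ordinary indexing. The tensor x and the names a and b are illustrative only.

```python
import torch

# A 3D tensor with shape (2, 3, 4)
x = torch.arange(24).reshape(2, 3, 4)

# Selecting index 1 along dim 0 drops dim 0, leaving shape (3, 4)
a = torch.select(x, 0, 1)

# Selecting index 2 along dim 2 drops dim 2, leaving shape (2, 3)
b = torch.select(x, 2, 2)

# The same slices written with ordinary indexing
print(torch.equal(a, x[1]))        # True
print(torch.equal(b, x[:, :, 2]))  # True
print(a.shape, b.shape)            # torch.Size([3, 4]) torch.Size([2, 3])
```

In this example, torch.select(x, 0, 1) gives the same values as x[1], and torch.select(x, 2, 2) gives the same values as x[:, :, 2].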
Syntax
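```python
torch.select(input, dim, index)
tensor.select(dim, index)
```

- input: The input tensor (for the method form, this is the tensor the method is called on).
- dim (int): The dimension along which to select.
- index (int): The index of the slice to select along the specified dimension.

The return value is a tensor that shares storage with the input. Its dimension dim is removed.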
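The following sketch calls the method form and passes negative values for dim and index. It assumes that negative values count from the end, as in Python indexing. The tensor t is illustrative only.

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# Method form: t.select(dim, index) behaves like torch.select(t, dim, index)
first_row = t.select(0, 0)    # tensor([1, 2, 3])

# Negative dim and index count from the end (last dimension, last index)
last_col = t.select(-1, -1)   # tensor([3, 6])

print(first_row)
print(last_col)
```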
Example
The following example illustrates the usage of the .select() method on a 2D tensor:
```python
import torch

# 2D tensor
tensor = torch.tensor([[10, 20], [30, 40], [50, 60]])
print("Input Tensor: ", tensor)

# Select a row (dim=0)
row = torch.select(tensor, 0, 1)
print("\nSelected Row (dim=0, index=1):", row)

# Select a column (dim=1)
col = torch.select(tensor, 1, 0)
print("\nSelected Column (dim=1, index=0):", col)
```
The above code gives the following output:
```
Input Tensor:  tensor([[10, 20],
        [30, 40],
        [50, 60]])

Selected Row (dim=0, index=1): tensor([30, 40])

Selected Column (dim=1, index=0): tensor([10, 30, 50])
```
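The result of .select() is a view, so writing to it changes the original tensor. The sketch below reuses the tensor from the example above. It calls .clone() when an independent copy is needed.

```python
import torch

tensor = torch.tensor([[10, 20], [30, 40], [50, 60]])

# select() returns a view that shares storage with the input
row = tensor.select(0, 1)
row[0] = 99
print(tensor[1])     # tensor([99, 40]) (the original tensor changed)

# clone() the selected slice to get an independent copy
col = tensor.select(1, 0).clone()
col[0] = -1
print(tensor[0, 0])  # tensor(10) (the original tensor is unchanged)
```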