.select()
Published Dec 16, 2024
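The .select() method in PyTorch returns the slice of a tensor at a given index along a specified dimension. The selected dimension is removed, so the output tensor has one fewer dimension than the input tensor. The result is a view of the input, meaning it shares the same underlying data rather than copying it.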
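The dimensionality reduction is easiest to see with a tensor of more than two dimensions. The following is a minimal sketch using an arbitrary tensor of shape (2, 3, 4), built with torch.arange() purely for illustration. Selecting along dim=1 drops that dimension from the shape:

```python
import torch

# Illustrative 3D tensor with shape (2, 3, 4), filled with 0..23
x = torch.arange(24).reshape(2, 3, 4)

# Select index 1 along dim=1; dimension 1 disappears from the result
y = torch.select(x, 1, 1)

print(x.shape)  # torch.Size([2, 3, 4])
print(y.shape)  # torch.Size([2, 4])
```

Here y holds the same values as x[:, 1, :], namely the second entry along dimension 1 for each entry of dimension 0.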
Syntax
torch.select(input, dim, index)
The same operation is also available as a tensor method: tensor.select(dim, index)
- input: The input tensor.
- dim: The dimension along which to select.
- index: The index of the slice to select along the specified dimension.
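The sketch below compares the function form and the tensor-method form of the call, along with the basic indexing expression they correspond to. The PyTorch documentation describes select(0, index) as equivalent to tensor[index] and select(1, index) as equivalent to tensor[:, index]. It reuses the 2D tensor from the example on this page and also shows that a negative index counts back from the end of the chosen dimension:

```python
import torch

tensor = torch.tensor([[10, 20], [30, 40], [50, 60]])

# Function form and method form take the same dim/index arguments
a = torch.select(tensor, 1, 0)
b = tensor.select(1, 0)

# Equivalent basic indexing: index 0 along dim 1 is tensor[:, 0]
c = tensor[:, 0]

print(torch.equal(a, b), torch.equal(a, c))  # True True

# A negative index counts from the end of the dimension
last_row = tensor.select(0, -1)
print(last_row)  # tensor([50, 60])
```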
Example
The following example illustrates the usage of the .select() method:
import torch

# 2D tensor
tensor = torch.tensor([[10, 20], [30, 40], [50, 60]])
print("Input Tensor: ", tensor)

# Select a row (dim=0)
row = torch.select(tensor, 0, 1)
print("\nSelected Row (dim=0, index=1):", row)

# Select a column (dim=1)
col = torch.select(tensor, 1, 0)
print("\nSelected Column (dim=1, index=0):", col)
The above code gives the following output:
Input Tensor:  tensor([[10, 20],
        [30, 40],
        [50, 60]])

Selected Row (dim=0, index=1): tensor([30, 40])

Selected Column (dim=1, index=0): tensor([10, 30, 50])
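Because torch.select() returns a view, writing into the selected slice also changes the original tensor. The following is a minimal sketch of that behavior using the same input tensor. Calling .clone() on the result is one way to get an independent copy when that is not wanted:

```python
import torch

tensor = torch.tensor([[10, 20], [30, 40], [50, 60]])

# select() returns a view that shares storage with `tensor`
row = torch.select(tensor, 0, 1)
row[0] = 99  # writes through to tensor[1, 0]
print(tensor[1])  # tensor([99, 40])

# .clone() produces an independent copy, so edits do not propagate
row_copy = torch.select(tensor, 0, 2).clone()
row_copy[0] = -1
print(tensor[2])  # tensor([50, 60]), unchanged
```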