.gather()
The .gather() function in PyTorch is a tensor operation that retrieves elements from a tensor along a specified dimension, using a tensor of indices to decide which values to take. This makes it useful for index-based selection in machine learning and data processing, where picking out specific values efficiently is a common need.
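As an illustration of the machine learning use case, a common pattern is picking, for each sample in a batch, the score of its target class. This is a minimal sketch; the names logits and labels and their values are hypothetical and chosen only for the example.

```python
import torch

# Hypothetical model scores: 3 samples x 4 classes
logits = torch.tensor([[0.1, 0.7, 0.1, 0.1],
                       [0.3, 0.2, 0.4, 0.1],
                       [0.9, 0.0, 0.05, 0.05]])

# Hypothetical target class for each sample (int64 by default)
labels = torch.tensor([1, 2, 0])

# index must have as many dimensions as input, so reshape (3,) -> (3, 1)
picked = torch.gather(logits, 1, labels.unsqueeze(1))  # shape (3, 1)

# Drop the trailing dimension to get one score per sample: 0.7, 0.4, 0.9
picked = picked.squeeze(1)
print(picked)

# For this 2-D case, advanced indexing selects the same values
same = logits[torch.arange(3), labels]
print(torch.equal(picked, same))  # True
```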
Syntax
```
torch.gather(input, dim, index)
```
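- input: The source tensor from which values are gathered.
- dim: The dimension along which to gather values. This is the axis in the input tensor where the selection occurs.
- index: A tensor of indices (int64) specifying which values to gather from the input tensor along the specified dim. It must have the same number of dimensions as input, and for every dimension other than dim, its size must not exceed the size of input in that dimension.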
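The choice of dim changes how each index is interpreted. This sketch, using hypothetical tensors src and idx, applies the same index tensor once along dimension 1 (indices select columns within each row) and once along dimension 0 (indices select rows within each column).

```python
import torch

src = torch.tensor([[1, 2],
                    [3, 4]])
idx = torch.tensor([[0, 1],
                    [1, 0]])

# dim=1: each index picks a column inside the same row
along_cols = torch.gather(src, 1, idx)  # values [[1, 2], [4, 3]]

# dim=0: each index picks a row inside the same column
along_rows = torch.gather(src, 0, idx)  # values [[1, 4], [3, 2]]

print(along_cols)
print(along_rows)
```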
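The function returns a tensor with the same shape as index, where each value is taken from input at the position given by the corresponding entry of index along dim. For a 2-D tensor, this means out[i][j] = input[i][index[i][j]] when dim is 1, and out[i][j] = input[index[i][j]][j] when dim is 0.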
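Because the output follows the shape of index rather than input, index can be smaller than the source tensor. This sketch, with hypothetical tensors src, idx, and idx_rows, gathers one element per row and then one element per column from a 2 x 3 tensor.

```python
import torch

# A 2 x 3 source tensor
src = torch.tensor([[10, 20, 30],
                    [40, 50, 60]])

# Index of shape (2, 1): one column position per row
idx = torch.tensor([[2],
                    [0]])
out = torch.gather(src, 1, idx)
print(out.shape)  # torch.Size([2, 1])
print(out)        # values [[30], [40]]

# Index of shape (1, 3) along dim 0: one row position per column
idx_rows = torch.tensor([[1, 0, 1]])
out_rows = torch.gather(src, 0, idx_rows)
print(out_rows)   # values [[40, 20, 60]]
```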
Example
Here’s an example of how .gather() can be used to select elements from a tensor based on specified indices:
```python
import torch

# Define a source tensor
input_tensor = torch.tensor([[1, 2], [3, 4]])

# Define the indices to gather
index_tensor = torch.tensor([[0, 1], [1, 0]])

# Gather elements from the source tensor along dimension 1
output_tensor = torch.gather(input_tensor, 1, index_tensor)

print(output_tensor)
```
This example results in the following output:
```
tensor([[1, 2],
        [4, 3]])
```
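In this example, .gather() retrieves elements from input_tensor along dimension 1, so each row of index_tensor lists column positions within the matching row of input_tensor. Row 0 uses indices [0, 1] and keeps 1 and 2 in their original order, while row 1 uses indices [1, 0] and picks 4 and then 3, swapping them. The result is a new tensor with the same shape as index_tensor.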
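Indices do not have to be a permutation: the same position can appear more than once, which duplicates that element in the output. This sketch reuses the example's input_tensor with a hypothetical index tensor dup_index and calls Tensor.gather, the method form of torch.gather.

```python
import torch

input_tensor = torch.tensor([[1, 2], [3, 4]])

# Each row repeats a single column index
dup_index = torch.tensor([[0, 0], [1, 1]])

# Method form: equivalent to torch.gather(input_tensor, 1, dup_index)
dup_out = input_tensor.gather(1, dup_index)
print(dup_out)  # values [[1, 1], [4, 4]]
```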
All contributors
- Anonymous contributor