PyTorch .gather()
The .gather() function in PyTorch is a tensor operation that retrieves values from a tensor along a specified dimension, using a tensor of indices to decide which element to take at each output position. It is useful whenever values must be selected by index, a common need in machine learning and data processing code.
Syntax
torch.gather(input, dim, index)
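- input: The source tensor from which values are gathered.
- dim: The dimension along which to gather values. This is the axis in the input tensor where the selection occurs.
- index: A tensor of indices (dtype torch.int64) specifying which values to gather from the input tensor along the specified dim. It must have the same number of dimensions as input, and for every dimension d other than dim, index.size(d) must not exceed input.size(d).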
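To illustrate the shape rules for index, here is a small sketch (the tensors src and idx are made up for illustration). It gathers along dim=1 with an index that has fewer rows than the source, which the rules permit, and also shows that the method form Tensor.gather(dim, index) gives the same result as torch.gather.

```python
import torch

# Illustrative 3x4 source tensor holding the values 0..11
src = torch.arange(12).reshape(3, 4)

# Illustrative index: int64, same number of dimensions as src (2).
# Along dim=1 it may have any size; along dim 0 it has 2 rows, which is allowed since 2 <= 3.
idx = torch.tensor([[3, 0],
                    [1, 1]])

out = torch.gather(src, 1, idx)
print(out.shape)  # torch.Size([2, 2]), the same shape as idx
print(out)

# Tensor.gather is the method form of the same operation
same = src.gather(1, idx)
print(torch.equal(out, same))
```

Here out holds [[3, 0], [5, 5]]: row 0 reads columns 3 and 0 of src, and row 1 reads column 1 twice, which shows that indices may repeat.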
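The function returns a tensor with the same shape as index, where each value is gathered from the input tensor based on the specified indices. For a 2-D tensor, out[i][j] = input[i][index[i][j]] when dim is 1, and out[i][j] = input[index[i][j]][j] when dim is 0.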
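The per-element rule behind .gather() can be spelled out with explicit loops. The sketch below uses a hypothetical helper, gather_2d_reference, that handles only 2-D tensors and is meant for understanding rather than real use (it is far slower than torch.gather). It checks itself against torch.gather for both dim=0 and dim=1 using the tensors from this page's example.

```python
import torch

def gather_2d_reference(inp, dim, index):
    """Hypothetical loop-based equivalent of torch.gather, 2-D tensors only."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = index[i, j].item()
            if dim == 0:
                out[i, j] = inp[k, j]  # row comes from index, column stays j
            else:
                out[i, j] = inp[i, k]  # column comes from index, row stays i
    return out

input_tensor = torch.tensor([[1, 2], [3, 4]])
index_tensor = torch.tensor([[0, 1], [1, 0]])

for d in (0, 1):
    expected = torch.gather(input_tensor, d, index_tensor)
    assert torch.equal(gather_2d_reference(input_tensor, d, index_tensor), expected)
    print(d, expected.tolist())
```

With dim=0 the index picks a row for each position instead of a column, so the same index_tensor produces [[1, 4], [3, 2]] rather than [[1, 2], [4, 3]].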
Example
Here’s an example of how .gather() can be used to select elements from a tensor based on specified indices:
import torch

# Define a source tensor
input_tensor = torch.tensor([[1, 2], [3, 4]])

# Define the indices to gather
index_tensor = torch.tensor([[0, 1], [1, 0]])

# Gather elements from the source tensor along dimension 1
output_tensor = torch.gather(input_tensor, 1, index_tensor)

print(output_tensor)
This example results in the following output:
tensor([[1, 2],
        [4, 3]])
In this example, .gather() works along dimension 1, so each row of index_tensor selects columns from the corresponding row of input_tensor: row 0 takes columns 0 and 1 (values 1 and 2), and row 1 takes columns 1 and 0 (values 4 and 3). The result is a new tensor with the same shape as index_tensor.
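A typical machine learning use is picking, for each sample, the score of its target class. The sketch below assumes made-up logits and targets tensors; because index must have the same number of dimensions as the input, the 1-D targets tensor is reshaped into a column with unsqueeze(1) before gathering along dim=1.

```python
import torch

# Hypothetical classifier scores: 3 samples x 4 classes
logits = torch.tensor([[0.1, 2.0, 0.3, 0.5],
                       [1.5, 0.2, 0.1, 3.0],
                       [0.0, 0.4, 2.2, 0.9]])

# Hypothetical target class per sample (int64, as gather requires)
targets = torch.tensor([1, 3, 2])

# Reshape targets to (3, 1) so index has the same number of dimensions as logits
picked = logits.gather(1, targets.unsqueeze(1))
print(picked.shape)  # torch.Size([3, 1])

# Drop the extra dimension to get one score per sample
print(picked.squeeze(1))
```

The result holds 2.0, 3.0 and 2.2, one score per sample. For this 2-D case, advanced indexing (logits[torch.arange(3), targets]) selects the same values; .gather() tends to be more convenient when the index tensor already matches the input's layout or when working with higher-dimensional tensors.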
All contributors
- Anonymous contributor