.take_along_dim()

muhammedazhar's avatar
Published Jan 20, 2025
Contribute to Docs

The .take_along_dim() function in PyTorch is used to select elements from a tensor along a specified dimension. This operation is essential for advanced indexing operations and manipulating multi-dimensional tensors in deep learning applications.

Similar to .take(), which extracts elements based on indices and always returns a 1D tensor, .take_along_dim() provides a more flexible approach by allowing indexing along a specific dimension while preserving the tensor’s shape.

Syntax

torch.take_along_dim(input, indices, dim)
  • input: The source tensor from which elements will be selected.
  • indices: A tensor of indices specifying which elements to select along the specified dimension.
  • dim: The dimension along which to perform the selection.

It returns a new tensor with the same dimensionality as the input tensor, containing the selected elements.

Example

Here is a basic usage example of .take_along_dim() in PyTorch to select elements along a specific dimension:

import torch
# Create a source tensor
input_tensor = torch.tensor([[10, 20, 30],
[40, 50, 60]])
# Define indices for selection
indices = torch.tensor([[2, 1, 0],
[1, 0, 2]])
# Select elements along dimension 1
result = torch.take_along_dim(input_tensor, indices, dim=1)
# Print the result
print(result)

The following will be the output of the above code:

tensor([[30, 20, 10],
[50, 40, 60]])

Moreover, the function can also be used to select elements along a specific dimension in a multi-dimensional tensor. For instance, the following example can be considered:

import torch
# Create a 3D tensor
input_tensor = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])
# Define indices for selection
indices = torch.tensor([[[0, 1], [1, 0]],
[[0, 0], [1, 1]]])
# Select elements along the last dimension
result = torch.take_along_dim(input_tensor, indices, dim=2)
# Print the result
print(result)

The output of the above code will be:

tensor([[[1, 2],
[4, 3]],
[[5, 5],
[8, 8]]])

Key Features

Here are some key features of the .take_along_dim() function:

  • Preserves tensor dimensionality during selection
  • Supports batch operations
  • Works with any number of dimensions
  • Maintains gradient information for backpropagation

Common Use Cases

Here are some common use cases of the .take_along_dim() function:

  • Sorting tensor elements
  • Implementing attention mechanisms
  • Selecting top-k elements
  • Custom pooling operations

Notes

Here are some notes about the .take_along_dim() function:

  • The indices tensor must have the same shape as the input tensor
  • Supports automatic differentiation
  • More flexible than the .take() function for multi-dimensional operations
  • Memory-efficient for large tensor operations

All contributors

Contribute to Docs

Learn PyTorch on Codecademy