.scatter_reduce()

Anonymous contributor
Published Jan 23, 2025

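In PyTorch, the .scatter_reduce() function writes the values of the source tensor into the input tensor at the positions given by an index tensor along a chosen dimension. Values that land on the same position are combined using the given reduction method ("sum", "prod", "mean", "amax", or "amin"). The function returns a new tensor and leaves the input tensor unchanged.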
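
As a rough mental model, the 1-D "sum" case can be sketched in plain Python. The helper name scatter_sum_1d is hypothetical and is not part of PyTorch; it only illustrates the idea for lists, assuming the default behavior where the existing input values take part in the reduction:

```python
def scatter_sum_1d(base, index, src):
    """Illustrative only: add each src[i] into position index[i] of a copy of base."""
    out = list(base)  # copy so the original list is left unchanged
    for position, value in zip(index, src):
        out[position] += value  # values sharing a position accumulate
    return out

result = scatter_sum_1d([21, 22, 23, 24], [0, 1, 2, 2, 1, 0], [11, 12, 13, 14, 15, 16])
print(result)  # [48, 49, 50, 24]
```

Position 3 is never named in the index list, so its value is carried over as is.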

Syntax

torch.scatter_reduce(input, dim, index, src, reduce, *, include_self=True)
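  • input: The input tensor that receives the scattered values.
  • dim: The dimension along which to index.
  • index: A tensor of integer (int64) indices, with the same number of dimensions as src, giving the position in input that each element of src is reduced into.
  • src: The source tensor containing the values to scatter and reduce.
  • reduce: The reduction method to apply, passed as a string: "sum", "prod", "mean", "amax", or "amin".
  • include_self: If True (default), the existing values of the input tensor are included in the reduction. If False, only the values from src are reduced at the positions that index refers to.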
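
The effect of include_self is easiest to see by running the same call twice. The following is a minimal sketch; the variable names base, with_self, and without_self are chosen here for illustration:

```python
import torch

base = torch.tensor([21, 22, 23, 24])
src = torch.tensor([11, 12, 13, 14, 15, 16])
index = torch.tensor([0, 1, 2, 2, 1, 0])

# Existing values of 'base' take part in the sum (the default)
with_self = torch.scatter_reduce(base, 0, index, src, reduce="sum", include_self=True)

# Only values from 'src' are summed at the positions named in 'index'
without_self = torch.scatter_reduce(base, 0, index, src, reduce="sum", include_self=False)

print(with_self.tolist())     # [48, 49, 50, 24]
print(without_self.tolist())  # [27, 27, 27, 24]
```

In both cases position 3 keeps its original value of 24, because no entry of index refers to it.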

Example

The following example demonstrates the usage of the .scatter_reduce() function:

import torch
# Create an input tensor
input = torch.tensor([21, 22, 23, 24])
# Create a source tensor containing the values
src = torch.tensor([11, 12, 13, 14, 15, 16])
# Create an index tensor containing the indices
index = torch.tensor([0, 1, 2, 2, 1, 0])
# Sum the values of 'src' into 'input' at the positions given by 'index' along dimension 0
res = torch.scatter_reduce(input, 0, index, src, reduce="sum")
# Print the resultant tensor
print(res)

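The above code produces the following output. Positions 0, 1, and 2 each hold their original value plus the two src values mapped to them (for example, 21 + 11 + 16 = 48), while position 3 is not referenced by index and keeps its original value of 24: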

tensor([48, 49, 50, 24])
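
The other reduction methods follow the same pattern. The sketch below uses floating-point tensors, chosen here so that "mean" yields ordinary averages, and the variable names are illustrative:

```python
import torch

base = torch.tensor([1.0, 2.0, 3.0, 4.0])
src = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
index = torch.tensor([0, 1, 2, 2, 1, 0])

# Largest value seen at each position, including the existing values of 'base'
largest = torch.scatter_reduce(base, 0, index, src, reduce="amax")

# Smallest value seen at each position, including the existing values of 'base'
smallest = torch.scatter_reduce(base, 0, index, src, reduce="amin")

# Average of the 'src' values only, since include_self=False
average = torch.scatter_reduce(base, 0, index, src, reduce="mean", include_self=False)

print(largest.tolist())   # [60.0, 50.0, 40.0, 4.0]
print(smallest.tolist())  # [1.0, 2.0, 3.0, 4.0]
print(average.tolist())   # [35.0, 35.0, 35.0, 4.0]
```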
