PyTorch index_add()
Published Nov 23, 2024
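In PyTorch, the .index_add() function adds the values of a source tensor to an input tensor at the positions given by an index tensor, along a specified dimension. If an index appears more than once, all of the corresponding source values are accumulated at that position. torch.index_add() returns a new tensor and leaves the input unchanged. The in-place variant is Tensor.index_add_().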
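For dim=0, the operation behaves like a loop that adds alpha * source[i] to position index[i]. The sketch below compares torch.index_add() against such a loop on a 1-D tensor with a repeated index. index_add_loop is a hypothetical helper written only for illustration. It is not part of PyTorch and only handles the 1-D, dim=0 case.

```python
import torch

# Hypothetical reference helper (not a PyTorch API): for each i,
# add alpha * source[i] to position index[i] of a copy of the input.
def index_add_loop(inp, index, source, alpha=1):
    result = inp.clone()  # torch.index_add() does not modify its input either
    for i, idx in enumerate(index.tolist()):
        result[idx] += alpha * source[i]
    return result

inp = torch.zeros(4)
index = torch.tensor([0, 3, 0])          # index 0 appears twice
source = torch.tensor([1.0, 2.0, 5.0])

expected = index_add_loop(inp, index, source)
actual = torch.index_add(inp, 0, index, source)

print(actual.tolist())                # [6.0, 0.0, 0.0, 2.0] (1.0 + 5.0 accumulate at index 0)
print(torch.equal(actual, expected))  # True
```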
Syntax
torch.index_add(input, dim, index, source, *, alpha=1, out=None)
- input: The tensor to which values will be added.
- dim: The dimension along which to index and add.
- index: A 1-D tensor of indices into input along dim. Its length must equal the size of source along dim.
- source: The tensor containing the values to add. It must have the same number of dimensions as input, with its size along dim equal to the length of index and all other dimensions matching input.
- alpha (optional): A scalar multiplier applied to source before it is added. Defaults to 1.
- out (optional): If provided, the result is written to this tensor.
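To show how the shape rules and the optional parameters fit together, here is a minimal sketch on a 2-D tensor. It uses dim=1, so whole columns of source are added into selected columns of the input, scaled by alpha. The second call writes into a preallocated tensor through out. The tensor values are arbitrary and chosen only for illustration.

```python
import torch

# A 2 x 3 input filled with ones.
x = torch.ones(2, 3)

# 1-D index: its length (2) must equal source.size(1) because dim=1.
index = torch.tensor([0, 2])

# source matches x in every dimension except dim 1, where its size is len(index).
source = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0]])

# With alpha=10: x[:, index[j]] + 10 * source[:, j] for each column j.
result = torch.index_add(x, 1, index, source, alpha=10)
print(result.tolist())  # [[11.0, 1.0, 21.0], [31.0, 1.0, 41.0]]

# out= stores the result in an existing tensor of the right shape and dtype.
buffer = torch.empty(2, 3)
torch.index_add(x, 1, index, source, alpha=10, out=buffer)
print(torch.equal(buffer, result))  # True
```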
Example
The following example demonstrates the usage of the .index_add() method:
import torch
# Define the input (named input_tensor to avoid shadowing Python's built-in input)
input_tensor = torch.zeros(5)
# Indices where the updates will occur
index = torch.tensor([0, 2, 4])
# The tensor containing the values to be added, converted to float to match the float32 input
source = torch.tensor([10, 20, 30]).float()
# Add the values at the specified indices along dimension 0
result = torch.index_add(input_tensor, 0, index, source)
print("Updated Tensor:", result)
The above code produces the following output:
Updated Tensor: tensor([10.,  0., 20.,  0., 30.])
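The values 10, 20, and 30 from the source tensor are added to positions 0, 2, and 4 of the input tensor. Because the input is all zeros, those positions end up holding exactly those values, and the other positions stay at 0. Since torch.index_add() is not in-place, the input tensor itself is still all zeros after the call.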
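The difference between the out-of-place function and the in-place method can be seen by repeating the example both ways. The sketch below uses the same values as the example above. It shows that torch.index_add() leaves x untouched, while Tensor.index_add_() modifies x and accumulates on top of whatever x already holds.

```python
import torch

x = torch.zeros(5)
index = torch.tensor([0, 2, 4])
source = torch.tensor([10.0, 20.0, 30.0])

# Out-of-place: returns a new tensor; x is unchanged.
result = torch.index_add(x, 0, index, source)
print(x.tolist())       # [0.0, 0.0, 0.0, 0.0, 0.0]
print(result.tolist())  # [10.0, 0.0, 20.0, 0.0, 30.0]

# In-place: Tensor.index_add_() writes into x.
x.index_add_(0, index, source)
print(x.tolist())       # [10.0, 0.0, 20.0, 0.0, 30.0]

# A second call adds again on top of the current values.
x.index_add_(0, index, source)
print(x.tolist())       # [20.0, 0.0, 40.0, 0.0, 60.0]
```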