index_add()
Published Nov 23, 2024
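In PyTorch, the .index_add() function adds the values of a source tensor, optionally scaled by alpha, into an input tensor at specific indices along a specified dimension. It returns a new tensor and leaves input unchanged.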
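For a 1-D tensor along dimension 0, the operation amounts to `result[index[i]] += alpha * source[i]` for every position i. The minimal sketch below checks that reading against an explicit loop. The variable names and the value alpha=2 are illustrative choices, not part of the API.

```python
import torch

# Start from zeros and choose where the source values should go
x = torch.zeros(5)
index = torch.tensor([0, 2, 4])
source = torch.tensor([1.0, 2.0, 3.0])
alpha = 2

# Built-in, out-of-place version
result = torch.index_add(x, 0, index, source, alpha=alpha)

# Equivalent explicit loop for a 1-D tensor along dim 0:
# expected[index[i]] += alpha * source[i]
expected = x.clone()
for i in range(index.numel()):
    expected[int(index[i])] += alpha * source[i]

print(result)                         # tensor([2., 0., 4., 0., 6.])
print(torch.equal(result, expected))  # True
print(x)                              # input is unchanged: all zeros
```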
Syntax
torch.index_add(input, dim, index, source, *, alpha=1, out=None)
- input: The tensor to which values will be added.
- dim: The dimension along which to index and add.
- index: A 1-D tensor (int32 or int64) of indices into input along dim. Its length must equal the size of source along dim.
- source: The tensor containing the values to add. Its size along dim must equal the length of index, and all other dimensions must match input.
- alpha (optional, keyword-only): A scalar multiplier applied to source before it is added. Defaults to 1.
- out (optional, keyword-only): If provided, the result is written to this tensor.
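The parameters are easier to see with a 2-D tensor. The sketch below adds along dim=1, so each column of source, scaled by alpha=10, is added to the column of x named by the matching entry of index. It also shows the optional out parameter writing into a preallocated tensor. The shapes and values are arbitrary assumptions chosen for illustration.

```python
import torch

# A 3x4 matrix of zeros to accumulate into
x = torch.zeros(3, 4)

# Add into columns 1 and 3 (dim=1)
index = torch.tensor([1, 3])

# source has size len(index) == 2 along dim=1 and matches x in dim 0: shape (3, 2)
source = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0],
                       [5.0, 6.0]])

# Column j of source is multiplied by alpha and added to column index[j] of x
result = torch.index_add(x, 1, index, source, alpha=10)
print(result)

# Optionally write the same result into a preallocated tensor via out=
buffer = torch.empty(3, 4)
torch.index_add(x, 1, index, source, alpha=10, out=buffer)
```

Columns 0 and 2 stay at zero, while columns 1 and 3 receive 10 times the first and second columns of source.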
Example
The following example demonstrates the usage of the torch.index_add() function:
```python
import torch

# Define the input
input = torch.zeros(5)

# Indices where the updates will occur
index = torch.tensor([0, 2, 4])

# The tensor containing the values to be added
source = torch.tensor([10, 20, 30]).float()

# Add the values to specified indices
result = torch.index_add(input, 0, index, source)

print("Updated Tensor:", result)
```
The above code produces the following output:

```
Updated Tensor: tensor([10.,  0., 20.,  0., 30.])
```
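The values [10, 20, 30] from the source tensor are added to positions [0, 2, 4] of the input tensor. Because input starts as all zeros, those positions end up holding exactly 10, 20, and 30, while the remaining positions stay 0.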
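When index contains the same position more than once, the contributions accumulate rather than overwrite each other, which makes .index_add() handy for summing values per group. The sketch below also uses the in-place Tensor.index_add_() method, which modifies the tensor it is called on. The values and group IDs are hypothetical sample data.

```python
import torch

# Hypothetical per-item values and the group each item belongs to
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
group_ids = torch.tensor([0, 1, 0, 2, 1])

# One accumulator slot per group
totals = torch.zeros(3)

# In-place variant: repeated indices add up in the same slot
totals.index_add_(0, group_ids, values)

print(totals)  # tensor([4., 7., 4.])
```

Group 0 collects 1 + 3, group 1 collects 2 + 5, and group 2 collects 4.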