.scatter()
Anonymous contributor
Published Dec 23, 2024
In PyTorch, the `.scatter()` function writes values from a source (a tensor or a scalar) into specific locations of a tensor along a specified dimension, based on given indices, and returns the result as a new tensor.
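Because the source can be a scalar, a common use of `.scatter()` is building one-hot encodings from integer class labels. The following is a minimal sketch; the names `labels`, `zeros`, and `one_hot` and the four-class setup are illustrative assumptions:

```python
import torch

# Class labels for three samples (index tensors must be int64)
labels = torch.tensor([2, 0, 3])

# One row per sample, one column per class (4 classes assumed)
zeros = torch.zeros(3, 4)

# Write the scalar 1.0 into column labels[i] of row i (dim=1).
# The index needs as many dimensions as the target, so the 1-D
# labels are reshaped into a (3, 1) column with unsqueeze(1).
one_hot = torch.scatter(zeros, 1, labels.unsqueeze(1), 1.0)

print(one_hot.tolist())
# [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
```

Since `torch.scatter()` is out-of-place, `zeros` still holds only zeros afterwards.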
Syntax
```py
torch.scatter(ten, dim, index, src)

# Equivalent Tensor method form
ten.scatter(dim, index, src)
```
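Both forms above leave `ten` unchanged. PyTorch also provides the in-place method `Tensor.scatter_()` (note the trailing underscore), which overwrites the tensor it is called on. A minimal sketch contrasting the two, using made-up values:

```python
import torch

ten = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[2], [0]])  # one target column per row
src = torch.tensor([[90], [70]])

# Out-of-place: torch.scatter() returns a new tensor
res = torch.scatter(ten, 1, index, src)
print(res.tolist())  # [[1, 2, 90], [70, 5, 6]]
print(ten.tolist())  # [[1, 2, 3], [4, 5, 6]] (unchanged)

# The method form performs the same out-of-place operation
assert torch.equal(ten.scatter(1, index, src), res)

# In-place: Tensor.scatter_() writes directly into ten
ten.scatter_(1, index, src)
print(ten.tolist())  # [[1, 2, 90], [70, 5, 6]]
```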
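- `ten`: The tensor into which the values are written. It is not modified; the result is returned as a new tensor.
- `dim`: The dimension along which the indices in `index` are applied.
- `index`: A tensor of type `torch.int64` that specifies where in `ten` each value from `src` is written. It must have the same number of dimensions as `ten`, and its values must lie between `0` and `ten.size(dim) - 1`.
- `src`: The tensor containing the values to be written, or a single scalar value that is written at every indexed location.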
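For 2-D tensors, the rule `.scatter()` applies can be restated as a plain loop: with `dim=1`, `out[i][index[i][j]] = src[i][j]`; with `dim=0`, `out[index[i][j]][j] = src[i][j]`. The sketch below is illustrative only: `scatter_2d_reference` is a hypothetical helper, not part of PyTorch, and it loops over the shape of `index`, so only the elements of `src` that `index` covers are written. The sample indices contain no duplicate targets, because PyTorch does not guarantee which value wins when several indices point to the same location.

```python
import torch

def scatter_2d_reference(ten, dim, index, src):
    """Hypothetical helper: restates torch.scatter for 2-D tensors."""
    out = ten.clone()  # work on a copy, as torch.scatter does
    rows, cols = index.shape  # only positions covered by index are written
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])
            if dim == 0:
                out[k, j] = src[i, j]  # index chooses the row, j keeps the column
            else:
                out[i, k] = src[i, j]  # i keeps the row, index chooses the column
    return out

ten = torch.zeros(3, 4, dtype=torch.int64)
index = torch.tensor([[0, 1, 2], [2, 0, 1]])  # no duplicate targets
src = torch.tensor([[1, 2, 3], [4, 5, 6]])

for d in (0, 1):
    expected = torch.scatter(ten, d, index, src)
    assert torch.equal(scatter_2d_reference(ten, d, index, src), expected)
    print(d, expected.tolist())
# 0 [[1, 5, 0, 0], [0, 2, 6, 0], [4, 0, 3, 0]]
# 1 [[1, 2, 3, 0], [5, 6, 4, 0], [0, 0, 0, 0]]
```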
Example
The following example demonstrates the usage of the `.scatter()` function:
```py
import torch

# Create a tensor
ten = torch.tensor([[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]])

# Create a tensor containing the locations (column indices for each row)
index = torch.tensor([[0, 2], [1, 3]])

# Create a tensor containing the values
src = torch.tensor([[21, 23], [27, 29]])

# Along dimension 1, write src[i][j] into position [i][index[i][j]];
# the result is a new tensor and ten itself is left unchanged
res = torch.scatter(ten, 1, index, src)

# Print the resultant tensor
print(res)
```
The above code produces the following output:
```shell
tensor([[21, 12, 23, 14, 15],
        [16, 27, 18, 29, 20]])
```
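Switching `dim` to `0` makes each index value choose a row instead of a column, while the column comes from the index's own position. A minimal sketch reusing the same starting tensor, with hypothetical `index` and `src` values:

```python
import torch

ten = torch.tensor([[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]])

# Hypothetical values: with dim=0, each index selects a row (0 or 1),
# and the column is the position j of that index
index = torch.tensor([[1, 0, 1]])
src = torch.tensor([[31, 32, 33]])

res = torch.scatter(ten, 0, index, src)
print(res.tolist())  # [[11, 32, 13, 14, 15], [31, 17, 33, 19, 20]]
```

Columns 3 and 4 keep their original values because `index` only covers the first three columns.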