PyTorch .select_scatter()
Published Jan 23, 2025
In PyTorch, the .select_scatter() function embeds the values of the src tensor into the input tensor at a single index along a given dimension. It returns a new tensor and leaves input unchanged.
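One way to picture what .select_scatter() does is to clone the input, take the slice chosen by .select(dim, index), and copy src into that slice. The following is a minimal sketch of that idea. manual_select_scatter() is a hypothetical helper written only for illustration, not part of PyTorch, and it assumes src already has exactly the slice's shape.

```python
import torch

def manual_select_scatter(tensor, src, dim, index):
    # Hypothetical helper for illustration only; not a PyTorch API
    out = tensor.clone()               # copy so the caller's tensor is left untouched
    out.select(dim, index).copy_(src)  # select() returns a view, so copy_() writes into out
    return out

x = torch.zeros(2, 3)
src = torch.ones(2)

builtin = torch.select_scatter(x, src, 1, 2)  # fill column 2 of a copy of x
manual = manual_select_scatter(x, src, 1, 2)

print(torch.equal(builtin, manual))  # True
```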
Syntax
torch.select_scatter(input, src, dim, index)
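- input: The input tensor.
- src: The source tensor whose values are embedded into input. Its shape must match the shape of input.select(dim, index), that is, the shape of input with dimension dim removed.
- dim: The dimension along which src is embedded.
- index: The position along dim at which src is embedded.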
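The shape that src needs depends on dim. The sketch below uses an arbitrary 2x3 tensor (the values are illustrative only), prints the slice shape for each dimension, and then embeds a row along dim 0 and a column along dim 1.

```python
import torch

x = torch.zeros(2, 3)

# Along dim 0, each slice is a row of shape (3,)
print(x.select(0, 1).shape)  # torch.Size([3])
row = torch.select_scatter(x, torch.tensor([1.0, 2.0, 3.0]), 0, 1)

# Along dim 1, each slice is a column of shape (2,)
print(x.select(1, 0).shape)  # torch.Size([2])
col = torch.select_scatter(x, torch.tensor([7.0, 8.0]), 1, 0)

print(row)
print(col)
```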
Example
The following example demonstrates the usage of the .select_scatter() function:
import torch

# Create a 3x3 input tensor with all elements set to 0
input_tensor = torch.zeros(3, 3)

# Create a source tensor; its shape (3,) matches one row of input_tensor
src = torch.tensor([4.0, 5.0, 6.0])

# Embed src as row 0 (dim=0, index=0) of a copy of input_tensor
res = torch.select_scatter(input_tensor, src, 0, 0)

# Print the resultant tensor
print(res)
The above code produces the following output:
tensor([[4., 5., 6.],
        [0., 0., 0.],
        [0., 0., 0.]])
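Two properties are worth checking directly: the original tensor is not modified, and a src whose shape does not fit the selected slice is rejected with an error. The sketch below reuses the values from the example with x as the input name; the mismatched src of length 2 is an arbitrary choice for illustration.

```python
import torch

x = torch.zeros(3, 3)
src = torch.tensor([4.0, 5.0, 6.0])

res = torch.select_scatter(x, src, 0, 0)

# select_scatter is out-of-place: x still holds only zeros
print(x.sum().item())    # 0.0
print(res.sum().item())  # 15.0

# x.select(0, 0) has shape (3,), so a length-2 src cannot be embedded
raised = False
try:
    torch.select_scatter(x, torch.tensor([1.0, 2.0]), 0, 0)
except RuntimeError:
    raised = True
print(raised)  # True
```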