PyTorch .masked_select()
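In PyTorch, .masked_select() is a function that selects elements from an input tensor based on a boolean mask. The mask does not have to match the input's shape exactly, but the two shapes must be broadcastable. It returns a new 1D tensor containing the elements where the corresponding mask value is True, in the order they appear in the flattened (row-major) input.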
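As a quick illustration, here is a minimal sketch that assumes a recent PyTorch release. The tensor `x` is a hypothetical example input. It shows that a mask is often built with a comparison such as `x > 0`, and that when the mask has the same shape as the input, the result matches boolean indexing with `x[mask]`:

```python
import torch

# A 2D input and a boolean mask of the same shape
x = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0]])
mask = x > 0  # comparison operators produce a torch.bool tensor

# masked_select always returns a flattened 1D tensor
selected = torch.masked_select(x, mask)
print(selected)  # tensor([1., 3., 5.])

# For a same-shape mask, boolean indexing gives the same result
print(torch.equal(selected, x[mask]))  # True
```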
Syntax
```python
torch.masked_select(input, mask, *, out=None)
```
- input: The input tensor from which elements will be selected.
- mask: A boolean tensor (dtype torch.bool) whose shape is broadcastable with input. True marks the elements to be selected.
- out (optional, keyword-only): A tensor to store the result. If provided, the selected elements are written to this tensor instead of a newly allocated one.
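The sketch below, assuming a recent PyTorch release, exercises the parameters above. The names `x`, `col_mask`, and `result` are illustrative. A 1D mask of length 3 is broadcast across both rows of a 2×3 input, the same call is made through the equivalent `Tensor.masked_select()` method, and the result is written into a preallocated empty tensor via `out`:

```python
import torch

x = torch.arange(6).reshape(2, 3)  # tensor([[0, 1, 2], [3, 4, 5]])

# A 1D mask of length 3 broadcasts across both rows of x
col_mask = torch.tensor([True, False, True])
print(torch.masked_select(x, col_mask))  # tensor([0, 2, 3, 5])

# The same operation is available as a Tensor method
print(x.masked_select(col_mask))  # tensor([0, 2, 3, 5])

# Write into a preallocated tensor with the keyword-only out= argument;
# an empty tensor of matching dtype is resized to fit the result
result = torch.empty(0, dtype=x.dtype)
torch.masked_select(x, col_mask, out=result)
print(result)  # tensor([0, 2, 3, 5])
```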
Example
Here’s an example of using .masked_select() in PyTorch:
```python
import torch

# Create an input tensor
input_tensor = torch.tensor([1, 2, 3, 4, 5])

# Create a mask tensor with boolean values
mask = torch.tensor([True, False, True, False, True])

# Use masked_select to extract elements from the input tensor where the mask is True
selected_elements = torch.masked_select(input_tensor, mask)

# Print the selected elements
print(selected_elements)
```
The code above produces the following output:
```
tensor([1, 3, 5])
```
In this example, the input_tensor contains elements [1, 2, 3, 4, 5], and the mask tensor contains boolean values [True, False, True, False, True]. The masked_select() function selects elements from the input_tensor where the corresponding mask value is True, resulting in the tensor [1, 3, 5].
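In practice the mask is rarely typed out by hand. This is a minimal sketch assuming a recent PyTorch release, and the "odd values" condition is only an illustrative choice. It derives the same mask from a condition on `input_tensor` and checks that the output length equals the number of True values in the mask:

```python
import torch

input_tensor = torch.tensor([1, 2, 3, 4, 5])

# Build the mask from a condition (odd values) instead of writing it by hand
mask = input_tensor % 2 == 1  # tensor([True, False, True, False, True])

selected_elements = torch.masked_select(input_tensor, mask)
print(selected_elements)  # tensor([1, 3, 5])

# The output length equals the number of True values in the mask
print(selected_elements.numel() == int(mask.sum()))  # True
```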
The .masked_select() function is useful for filtering a tensor by a condition, where the mask is typically produced by a comparison such as input > 0. Because the result is always a flattened 1D copy, it suits cases where only the selected values matter, such as computing statistics over a subset of elements or collecting them for further processing, analysis, or model training.
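The following sketch, assuming a recent PyTorch release, shows one such use. The `scores` tensor and the 0.5 threshold are arbitrary example values. It computes a statistic over the elements above a threshold, confirms that the result is a separate copy, and, for contrast, uses `torch.where()` when the original shape must be preserved:

```python
import torch

scores = torch.tensor([[0.9, 0.2, 0.7],
                       [0.4, 0.8, 0.1]])

# Keep only the scores above a threshold (0.5 is an arbitrary example value)
high = torch.masked_select(scores, scores > 0.5)
print(high)         # tensor([0.9000, 0.7000, 0.8000])
print(high.mean())  # mean of the three selected scores

# The result is a new tensor: modifying it does not change `scores`
high[0] = 0.0
print(scores[0, 0])  # tensor(0.9000)

# To keep the original 2x3 shape, torch.where can zero out the rest instead
kept = torch.where(scores > 0.5, scores, torch.zeros_like(scores))
print(kept.shape)  # torch.Size([2, 3])
```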