.masked_select()
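In PyTorch, .masked_select() is a function that selects elements from an input tensor based on a boolean mask. The mask must either have the same shape as the input or a shape that is broadcastable with it. The function returns a new 1D tensor containing the elements where the corresponding mask value is True. This tensor does not share storage with the input.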
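The 1D result is easiest to see with a 2D input. In the sketch below, the mask is built from a comparison (x > 0), which is one common way to get a boolean tensor of matching shape. The selected elements come back as a flat tensor in row-major order. Changing that tensor afterwards leaves the input untouched, because the values are copied.

```python
import torch

# A 2-D input; x > 0 yields a torch.bool mask with the same shape as x
x = torch.tensor([[1, -2, 3],
                  [-4, 5, -6]])
mask = x > 0

# The result is flattened to 1-D, walking x row by row
selected = torch.masked_select(x, mask)
print(selected)        # tensor([1, 3, 5])
print(selected.shape)  # torch.Size([3])

# The result is a copy, so modifying it does not change x
selected[0] = 100
print(x[0, 0])         # tensor(1)
```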
Syntax
torch.masked_select(input, mask, *, out=None)
- input: The input tensor from which elements will be selected.
- mask: A boolean tensor (dtype torch.bool) whose shape matches or is broadcastable with input. True marks the elements to be selected.
- out (Optional): A tensor to store the result. If provided, the selected elements are written to this tensor instead of a newly created one.
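The sketch below shows two details from the parameter list: a mask that is broadcast rather than matching the input's shape exactly, and the optional out argument. The names col_mask and result are illustrative only. An empty tensor of the right dtype is passed as out, and PyTorch resizes it to hold the selected elements.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# A mask of shape (3,) broadcasts against x of shape (2, 3),
# so columns 0 and 2 are selected from every row
col_mask = torch.tensor([True, False, True])
cols = torch.masked_select(x, col_mask)
print(cols)  # tensor([1, 3, 4, 6])

# Write the result into a preallocated tensor via out=
result = torch.empty(0, dtype=x.dtype)
torch.masked_select(x, col_mask, out=result)
print(result)  # tensor([1, 3, 4, 6])
```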
Example
Here’s an example of using .masked_select() in PyTorch:
import torch

# Create an input tensor
input_tensor = torch.tensor([1, 2, 3, 4, 5])

# Create a mask tensor with boolean values
mask = torch.tensor([True, False, True, False, True])

# Use masked_select to extract elements from the input tensor where the mask is True
selected_elements = torch.masked_select(input_tensor, mask)

# Print the selected elements
print(selected_elements)
The code above produces the following output:
tensor([1, 3, 5])
In this example, input_tensor contains the elements [1, 2, 3, 4, 5], and the mask tensor contains the boolean values [True, False, True, False, True]. The masked_select() function selects the elements of input_tensor where the corresponding mask value is True, resulting in the tensor [1, 3, 5].
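Two small variations on the same example can help. Inverting the mask with ~ selects the elements that were skipped. Boolean indexing with input_tensor[mask] returns the same values as masked_select() for a mask of matching shape. The variable name rest is illustrative.

```python
import torch

input_tensor = torch.tensor([1, 2, 3, 4, 5])
mask = torch.tensor([True, False, True, False, True])

# ~mask flips every boolean, so this picks the elements the original mask skipped
rest = torch.masked_select(input_tensor, ~mask)
print(rest)  # tensor([2, 4])

# Boolean indexing gives the same result as masked_select here
same = torch.equal(input_tensor[mask], torch.masked_select(input_tensor, mask))
print(same)  # True
```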
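The .masked_select() function is useful for filtering elements from a tensor based on a condition expressed as a mask tensor, often one built with a comparison such as input_tensor > 2. It can be used, for example, to pick out specific elements for further processing, analysis, or model training. Because the result is always a flat 1D tensor, the original shape of the input is not preserved.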
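As an illustration of the model-training use case, the sketch below averages a loss only over the entries a mask marks as valid. The per-token losses and the padding layout (True for real tokens, False for padding) are hypothetical values chosen for the example. The last line uses a different function, torch.where, which keeps the original shape by filling unselected positions with a replacement value instead of dropping them.

```python
import torch

# Hypothetical per-token losses for two sequences padded to length 4;
# values at padded positions are arbitrary and should be ignored
losses = torch.tensor([[0.5, 0.25, 0.75, 2.0],
                       [1.0, 0.5, 3.0, 4.0]])
# Assumed layout: True marks real tokens, False marks padding
valid = torch.tensor([[True, True, True, False],
                      [True, True, False, False]])

# Keep only the real-token losses, then average them
valid_losses = torch.masked_select(losses, valid)
mean_loss = valid_losses.mean()
print(valid_losses)  # tensor([0.5000, 0.2500, 0.7500, 1.0000, 0.5000])
print(mean_loss)     # tensor(0.6000)

# torch.where keeps the (2, 4) shape and puts 0.0 at the padded positions
kept_shape = torch.where(valid, losses, torch.zeros_like(losses))
```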