.where()
In PyTorch, the .where() function returns a new tensor whose elements are selected element-wise according to a condition. Each value comes from one source tensor where the condition is True and from another where it is False.
Syntax
torch.where(condition, input, other) → Tensor
- condition: A boolean tensor that controls which source each element is taken from.
- input: The tensor whose elements are selected where condition is True.
- other: The tensor whose elements are selected where condition is False.

It returns a tensor of elements selected from either input or other, based on condition. The three arguments must be broadcastable to a common shape, which is also the shape of the result.
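PyTorch broadcasts condition, input, and other to a common shape, so the three arguments do not need identical shapes. The minimal sketch below (the tensor names are illustrative) combines a 3 x 1 condition, a 1-D tensor of length 4, and a 3 x 4 tensor; the result takes the broadcast shape (3, 4).

```python
import torch

# A column-shaped condition (3 x 1) is broadcast across 4 columns
condition = torch.tensor([[True], [False], [True]])

# A 1-D tensor of shape (4,) is broadcast down 3 rows
row = torch.arange(4)

# A full (3 x 4) tensor supplies values where the condition is False
zeros = torch.zeros(3, 4, dtype=torch.long)

res = torch.where(condition, row, zeros)
print(res.shape)  # torch.Size([3, 4])
print(res)
# tensor([[0, 1, 2, 3],
#         [0, 0, 0, 0],
#         [0, 1, 2, 3]])
```

Rows where the condition is True receive a copy of row, and the remaining row is filled from zeros.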
Example
The following example demonstrates the usage of the .where() function:
import torch

# Define tensors
condition = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[9, 8], [7, 6]])

# Select elements based on the condition
res = torch.where(condition, x, y)

# Print the result
print(res)
The above code produces the following output:
tensor([[1, 8],
        [7, 4]])
In this example, the .where() function takes elements from x where condition is True and from y where condition is False. Because condition is True only on the diagonal, the result keeps 1 and 4 from x and fills the off-diagonal positions with 8 and 7 from y.
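One way to see what .where() computes is to compare it with an equivalent mask arithmetic expression. The sketch below reuses the tensors from the example and checks that torch.where(condition, x, y) matches condition * x + (~condition) * y for these integer tensors. The arithmetic form is shown only for illustration: with floating-point data it can go wrong, because multiplying an unselected inf by 0 produces NaN, whereas .where() simply ignores the unselected value.

```python
import torch

condition = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[9, 8], [7, 6]])

res = torch.where(condition, x, y)

# Equivalent mask arithmetic for these integer tensors:
# bool * int promotes to int, so condition keeps x where True
# and ~condition (logical NOT) keeps y where False.
manual = condition * x + (~condition) * y

print(torch.equal(res, manual))  # True
```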
The .where() function is useful for element-wise conditional operations, such as replacing values that fail a test or masking out unwanted entries, without writing explicit Python loops over tensor elements.
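As an illustration of such conditional operations, the sketch below clamps negative values to zero and replaces NaN entries with a fill value. It passes Python scalars as input or other, which assumes a reasonably recent PyTorch release; on older versions, a tensor such as torch.zeros_like(values) can be passed instead. The fill value -1.0 is arbitrary.

```python
import torch

values = torch.tensor([-1.5, 0.0, 2.0, float("nan"), -3.0])

# Clamp negatives to zero (a ReLU-style operation).
# Scalar arguments assume a recent PyTorch version.
# NaN < 0 evaluates to False, so the NaN entry passes through unchanged.
non_negative = torch.where(values < 0, 0.0, values)

# Replace NaN entries with an arbitrary fill value of -1.0
filled = torch.where(torch.isnan(values), -1.0, values)

print(non_negative)
print(filled)
```

Note that each call handles one condition only; the NaN in non_negative survives because the comparison with 0 does not select it.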