PyTorch .where()
In PyTorch, the .where() function returns a new tensor whose elements are chosen element-wise based on a condition: values come from one source where the condition is True and from another where it is False.
Syntax
torch.where(condition, input, other) → Tensor
- condition: A boolean tensor that controls the selection.
- input: The tensor (or scalar) to select elements from where condition is True.
- other: The tensor (or scalar) to select elements from where condition is False.
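condition, input, and other must be broadcastable to a common shape. The function returns a new tensor of that shape, with each element taken from input or other according to the corresponding value in condition.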
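The sketch below illustrates two practical consequences of the parameter rules: passing a Python scalar as other, and broadcasting a condition of a smaller shape. It assumes a recent PyTorch release (current versions accept scalars for input and other; older releases may require wrapping the scalar in a tensor). The tensor values and variable names are arbitrary and chosen only for illustration.

```python
import torch

# A 2x3 float tensor with some negative entries (illustrative values)
x = torch.tensor([[-1.0, 2.0, -3.0],
                  [4.0, -5.0, 6.0]])

# Scalar `other`: keep positive entries, replace the rest with 0.0 (ReLU-style)
relu_like = torch.where(x > 0, x, 0.0)

# Broadcasting: a (2, 1) condition selects whole rows, either from x
# or from a (3,) fill row that is broadcast to shape (2, 3)
row_mask = torch.tensor([[True], [False]])
fill_row = torch.tensor([10.0, 20.0, 30.0])
mixed = torch.where(row_mask, x, fill_row)

print(relu_like)
print(mixed)
print(mixed.shape)
```

Because the result takes the broadcast shape of all three arguments, mixed is 2x3 even though row_mask has only one column.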
Example
The following example demonstrates the usage of the .where() function:
import torch

# Define tensors
condition = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[9, 8], [7, 6]])

# Select elements based on the condition
res = torch.where(condition, x, y)

# Print the result
print(res)
The above code produces the following output:
tensor([[1, 8],
        [7, 4]])
In this example, .where() takes elements from x where condition is True (the diagonal positions) and from y where condition is False (the off-diagonal positions), which produces [[1, 8], [7, 4]].
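To make the element-wise rule concrete, the sketch below compares torch.where() against a plain Python loop on the same tensors. The helper where_reference is hypothetical, written only for illustration, and handles just same-shaped 2-D inputs (no broadcasting); it is far slower than the built-in and is not how PyTorch implements the function.

```python
import torch

condition = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[9, 8], [7, 6]])

# Hypothetical reference helper: picks each element with an explicit loop.
# Assumes cond, a, and b are 2-D tensors of the same shape.
def where_reference(cond, a, b):
    out = torch.empty_like(a)
    for i in range(cond.shape[0]):
        for j in range(cond.shape[1]):
            out[i, j] = a[i, j] if cond[i, j] else b[i, j]
    return out

expected = where_reference(condition, x, y)
res = torch.where(condition, x, y)
print(torch.equal(res, expected))  # True
```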
The .where() function is useful whenever values must be chosen element by element between two tensors, for example to mask, clamp, or replace invalid entries without writing Python loops.
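The following sketch shows a few such conditional operations: replacing NaN values, applying a piecewise function, and using the tensor method form. The readings tensor is made-up data. It relies on torch.isnan(), torch.zeros_like(), and the Tensor.where(condition, other) method, which is equivalent to torch.where(condition, self, other).

```python
import torch

# Made-up readings containing a missing value (NaN)
readings = torch.tensor([1.5, float("nan"), -2.0, 3.0])

# Replace NaN entries with 0.0; torch.isnan returns a boolean mask
cleaned = torch.where(torch.isnan(readings), torch.zeros_like(readings), readings)

# Piecewise function: square non-negative values, negate negative ones
piecewise = torch.where(cleaned >= 0, cleaned ** 2, -cleaned)

# Method form: self.where(condition, other) keeps self where condition is True
same = cleaned.where(cleaned >= 0, -cleaned)

print(cleaned)
print(piecewise)
print(same)
```

Note that both candidate tensors (here cleaned ** 2 and -cleaned) are computed in full before .where() runs; the function only selects between them. If one branch can produce NaN or infinite values, mask its inputs first rather than relying on .where() to skip them.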