PyTorch .take()
Published Nov 15, 2024
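The .take() function returns a new tensor containing the elements of the given tensor at the specified indices. The input is treated as if it were flattened to 1D, so the indices refer to positions in that flattened view regardless of the input's shape. The result has the same shape as the index tensor, so it is 1D whenever the indices are 1D.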
Syntax
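torch.take(input, index)
- input: The input tensor from which the elements will be selected. It is indexed as if it were flattened to 1D.
- index: A tensor of integer (torch.long) indices of the elements to extract from input. The indices refer to positions in the flattened input.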
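As a minimal sketch of how the two parameters interact, the snippet below passes a 2D index tensor. The names `index_2d` and `picked` are illustrative and not part of the API. It suggests that the indices address the flattened input and that the result follows the shape of the index tensor:

```python
import torch

# A 2x3 input; flattened, it reads [4, 2, -1, 7, 8, 0]
data = torch.tensor([[4, 2, -1], [7, 8, 0]])

# A 2x2 index tensor whose entries are positions in the flattened input
index_2d = torch.tensor([[0, 2], [5, 3]])

picked = torch.take(data, index_2d)

print(picked.shape)     # torch.Size([2, 2])
print(picked.tolist())  # [[4, -1], [0, 7]]
```

Here the result is 2x2 because the index tensor is 2x2, even though the input is 2x3.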
Example
The following example demonstrates the use of the .take() function:
import torch

# Define a tensor
data = torch.tensor([[4, 2, -1], [7, 8, 0]])

# Define indices as a tensor
indices = torch.tensor([0, 2, 5])

# Use torch.take with data and indices
result = torch.take(data, indices)
print(result)
The code produces the following output:
tensor([ 4, -1,  0])
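One way to read this result is as ordinary indexing into a flattened copy of the input. The sketch below compares the two approaches on the same data. The variable names `via_take` and `via_flatten` are illustrative only:

```python
import torch

data = torch.tensor([[4, 2, -1], [7, 8, 0]])
indices = torch.tensor([0, 2, 5])

# Select with torch.take directly on the 2D tensor
via_take = torch.take(data, indices)

# Flatten first, then index with the same positions
via_flatten = data.flatten()[indices]

print(torch.equal(via_take, via_flatten))  # True
```

For this input, both approaches select the elements at flattened positions 0, 2, and 5.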