.take()
Anonymous contributor
Published Nov 15, 2024
The `.take()` function returns a new tensor containing the elements of the given tensor at the specified indices. The input tensor is treated as if it were flattened to 1D (in row-major order), and the result has the same shape as the index tensor, not the shape of the input.
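A minimal sketch of the shape behavior described above, reusing the tensor from the example below. The 2D index tensor here is an illustrative choice, not part of the original page. The result follows the shape of the index, and the call matches indexing the flattened tensor directly:

```python
import torch

# Same 2x3 tensor as in the example; its row-major flattened order is [4, 2, -1, 7, 8, 0]
data = torch.tensor([[4, 2, -1], [7, 8, 0]])

# A 2D index tensor: the result takes this shape, not the shape of data
index = torch.tensor([[0, 5], [1, 3]])

result = torch.take(data, index)
print(result.tolist())  # [[4, 0], [2, 7]]
print(result.shape)     # torch.Size([2, 2])

# Equivalent to advanced indexing on the flattened tensor
print(torch.equal(result, data.flatten()[index]))  # True
```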
Syntax
```py
torch.take(input, index)
```
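- `input`: The input tensor from which the elements will be selected.
- `index`: A tensor of integer (`torch.long`) indices of the elements to extract from `input`. Each index refers to a position in `input` as if it were flattened to 1D in row-major order. `index` may have any shape.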
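Because indices refer to the flattened tensor, a `(row, col)` position in a 2D tensor maps to the flat index `row * num_cols + col`. The sketch below assumes a 2D input and uses hypothetical `rows`/`cols` tensors to show that conversion. It then checks the result against ordinary advanced indexing:

```python
import torch

data = torch.tensor([[4, 2, -1], [7, 8, 0]])
num_cols = data.size(1)  # 3 columns

# Hypothetical (row, col) coordinates to read: (0, 2) and (1, 1)
rows = torch.tensor([0, 1])
cols = torch.tensor([2, 1])

# Row-major flat index; integer tensors default to int64 (torch.long)
flat_index = rows * num_cols + cols
print(flat_index.tolist())  # [2, 4]

picked = torch.take(data, flat_index)
print(picked.tolist())  # [-1, 8]

# Same elements as indexing the 2D tensor directly
print(torch.equal(picked, data[rows, cols]))  # True
```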
Example
The following example demonstrates the usage of the `.take()` function:
```py
import torch

# Define a tensor
data = torch.tensor([[4, 2, -1], [7, 8, 0]])

# Define indices as a tensor
indices = torch.tensor([0, 2, 5])

# Use torch.take with data and indices
result = torch.take(data, indices)
print(result)
```
The code produces the following output:
```shell
tensor([ 4, -1,  0])
```