.tensor_split()
Published Jan 18, 2025
In PyTorch, the .tensor_split() function splits a tensor into multiple sub-tensors along a specified dimension and returns them as a tuple. All of the sub-tensors are views of the input tensor. If the tensor cannot be split evenly, the function distributes the elements across the sub-tensors as evenly as possible.
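What "as evenly as possible" means in practice: PyTorch's documentation states that when a dimension of size n is split into k sections and n is not divisible by k, the first n % k sections get n // k + 1 elements and the remaining sections get n // k. The following is a minimal sketch of that rule with 7 elements and 3 sections:

```python
import torch

# 7 elements split into 3 sections: 7 % 3 == 1, so the first
# section gets one extra element (sizes 3, 2, 2).
x = torch.arange(7)
parts = torch.tensor_split(x, 3)

for p in parts:
    print(p)
# tensor([0, 1, 2])
# tensor([3, 4])
# tensor([5, 6])
```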
Syntax
torch.tensor_split(input, indices_or_sections, dim=0)
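The same operation is also available as a Tensor method, input.tensor_split(indices_or_sections, dim=0), which is the form the page title refers to. The following is a short sketch showing that both forms give the same pieces:

```python
import torch

x = torch.arange(6)

# Functional form
a = torch.tensor_split(x, 3)

# Equivalent Tensor method form
b = x.tensor_split(3)

print(a)  # (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
print(all(torch.equal(p, q) for p, q in zip(a, b)))  # True
```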
- input: The tensor to be split.
- indices_or_sections: Either an integer or a sequence of indices.
  - If int: The number of sub-tensors to split the input tensor into. If the size of input along dim is not divisible by this number, the first (size % sections) sub-tensors each receive one extra element.
  - If list or tuple of ints: The indices at which to split the tensor along the specified dimension. For example, with dim=0, the indices [2, 3] produce input[:2], input[2:3], and input[3:].
- dim (optional): The dimension along which to split the tensor. Default is 0.
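The following sketch shows two cases the parameter list describes but the main example does not cover: passing a list of split indices, and splitting a 2-D tensor along a dimension other than 0. The tensor shapes and index values are arbitrary choices for illustration:

```python
import torch

x = torch.arange(8)

# A list of indices [2, 5] splits at those positions:
# x[:2], x[2:5], x[5:]
parts = torch.tensor_split(x, [2, 5])
print([p.tolist() for p in parts])  # [[0, 1], [2, 3, 4], [5, 6, 7]]

# Split a 3x4 tensor into two halves along its columns (dim=1)
m = torch.arange(12).reshape(3, 4)
left, right = torch.tensor_split(m, 2, dim=1)
print(left.shape, right.shape)  # torch.Size([3, 2]) torch.Size([3, 2])
```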
Example
The following example demonstrates the use of the .tensor_split() function:
import torch

# Create a one-dimensional tensor
x = torch.arange(10)

# Split the tensor into 2 parts
result = torch.tensor_split(x, 2)

# Print the result
print(result)
The code above produces the following output:
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
The output is a tuple of two sub-tensors. Because the 10 elements divide evenly by 2, each sub-tensor holds 5 elements along the tensor's only dimension.
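Because the sub-tensors come back as a tuple in their original order, concatenating them along the same dimension rebuilds the input. The following is a quick check of the example above; it uses torch.cat, which is not part of the original example:

```python
import torch

x = torch.arange(10)
result = torch.tensor_split(x, 2)

print(type(result).__name__)  # tuple
print([len(t) for t in result])  # [5, 5]

# Concatenating the pieces along dim 0 restores the original tensor
restored = torch.cat(result, dim=0)
print(torch.equal(restored, x))  # True
```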