.split()
Published Oct 28, 2024
Contribute to Docs
In PyTorch, the .split()
function is used to split a tensor into chunks of specified sizes along a specified dimension and returns a tuple of tensors. It is useful for processing smaller segments of data independently.
Syntax
torch.split(tensor, split_size_or_sections, dim=0)
tensor
: The tensor to be split. This is a required parameter.split_size_or_sections
: Specifies the size of each chunk. If an integer, it defines the number of elements in each chunk. If a list of integers, it specifies the exact size of each chunk in order.dim
(Optional): The dimension along which to split the tensor. The default value is 0.
Example
The following example illustrates the usage of .split()
in various scenarios:
import torch# a 1D tensor and a 2D tensortensor_1d = torch.arange(1, 13)tensor_2d = tensor_1d.reshape(2, 6)print(f"Input 1D Tensor - {tensor_1d}")# Case 1: Splitting into equal chunks of size 3equal_chunks = torch.split(tensor_1d, 3)print("\nCase 1: Equal Chunks (Size = 3)")for chunk in equal_chunks:print(chunk)# Case 2: Splitting into unequal chunks of sizes [4, 3, 5]unequal_chunks = torch.split(tensor_1d, [4, 3, 5])print("\nCase 2: Unequal Chunks (Sizes = [4, 3, 5])")for chunk in unequal_chunks:print(chunk)# Case 3: Attempting to split with non-divisible chunk sizenon_divisible_split = torch.split(tensor_1d, 5)print("\nCase 3: Non-Divisible Chunk Size (Size = 5)")for chunk in non_divisible_split:print(chunk)print(f"\nInput 2D Tensor - {tensor_2d}")# Case 4: Splitting a 2D tensor along rows (dim=0)row_split = torch.split(tensor_2d, 1, dim=0)print("\nCase 4: Split Along Rows (dim=0, size=1)")for chunk in row_split:print(chunk)# Case 5: Splitting a 2D tensor along columns (dim=1)col_split = torch.split(tensor_2d, 3, dim=1)print("\nCase 5: Split Along Columns (dim=1, size=3)")for chunk in col_split:print(chunk)# Case 6: Splitting a 2D tensor into unequal sizes along columns (dim=1)uneven_split_2d = torch.split(tensor_2d, [1, 3, 2], dim=1)print("\nCase 6: Unequal Split Sizes on 2D Tensor ([1, 3, 2], dim=1)")for chunk in uneven_split_2d:print(chunk)# Case 7: Unequal split with sizes that do not sum up to the tensor size# In this case, .split() raises an error when the sizes don't add up correctly.print("\nCase 7: Unequal Split Sizes That Do Not Sum Upto the 2D Tensor Size ([1, 3, 3], dim=1)")try:error_split = torch.split(tensor_2d, [1, 3, 3], dim=1)for chunk in error_split:print(chunk)except RuntimeError as e:print(f"RuntimeError - {e}")
The above program gives the following output:
Input 1D Tensor - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])Case 1: Equal Chunks (Size = 3)tensor([1, 2, 3])tensor([4, 5, 6])tensor([7, 8, 9])tensor([10, 11, 12])Case 2: Unequal Chunks (Sizes = [4, 3, 5])tensor([1, 2, 3, 4])tensor([5, 6, 7])tensor([ 8, 9, 10, 11, 12])Case 3: Non-Divisible Chunk Size (Size = 5)tensor([1, 2, 3, 4, 5])tensor([ 6, 7, 8, 9, 10])tensor([11, 12])Input 2D Tensor - tensor([[ 1, 2, 3, 4, 5, 6],[ 7, 8, 9, 10, 11, 12]])Case 4: Split Along Rows (dim=0, size=1)tensor([[1, 2, 3, 4, 5, 6]])tensor([[ 7, 8, 9, 10, 11, 12]])Case 5: Split Along Columns (dim=1, size=3)tensor([[1, 2, 3],[7, 8, 9]])tensor([[ 4, 5, 6],[10, 11, 12]])Case 6: Unequal Split Sizes on 2D Tensor ([1, 3, 2], dim=1)tensor([[1],[7]])tensor([[ 2, 3, 4],[ 8, 9, 10]])tensor([[ 5, 6],[11, 12]])Case 7: Unequal Split Sizes That Do Not Sum Upto the 2D Tensor Size ([1, 3, 3], dim=1)RuntimeError - split_with_sizes expects split_sizes to sum exactly to 6 (input tensor's size at dimension 1), but got split_sizes=[1, 3, 3]
Contribute to Docs
- Learn more about how to get involved.
- Edit this page on GitHub to fix an error or make an improvement.
- Submit feedback to let us know how we can improve Docs.
Learn PyTorch on Codecademy
- Career path
Computer Science
Looking for an introduction to the theory behind programming? Master Python while learning data structures, algorithms, and more!Includes 6 CoursesWith Professional CertificationBeginner Friendly75 hours - Free course
Intro to PyTorch and Neural Networks
Learn how to use PyTorch to build, train, and test artificial neural networks in this course.Intermediate3 hours