PyTorch .dsplit()
Published Nov 30, 2024
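In PyTorch, the .dsplit() function splits a tensor with three or more dimensions into multiple sub-tensors depthwise, that is, along the third dimension (dim=2). It returns the sub-tensors as a tuple and behaves like torch.tensor_split() called with dim=2.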
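To make "depthwise" concrete, the following sketch (assuming a recent PyTorch release; the tensor x is an arbitrary illustration) splits a 4D tensor and shows that only dimension 2 changes size. It also compares the result with torch.tensor_split(..., dim=2) and shows that a 2D tensor is rejected.

```python
import torch

# A 4D tensor of shape (2, 3, 4, 5); dsplit always cuts along dim 2,
# no matter how many dimensions come after it.
x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

parts = torch.dsplit(x, 2)  # dim 2 has size 4, so we get two pieces of size 2
print([p.shape for p in parts])
# [torch.Size([2, 3, 2, 5]), torch.Size([2, 3, 2, 5])]

# The same pieces come out of tensor_split along dim=2
same = torch.tensor_split(x, 2, dim=2)
print(all(torch.equal(a, b) for a, b in zip(parts, same)))  # True

# Tensors with fewer than three dimensions raise a RuntimeError
try:
    torch.dsplit(torch.zeros(4, 4), 2)
except RuntimeError as err:
    print("RuntimeError:", err)
```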
Syntax
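torch.dsplit(input, indices_or_sections)
- input: The tensor to be split. It must have at least three dimensions.
- indices_or_sections: Either an integer giving the number of equal-sized sub-tensors to create, or a list of indices at which to split along the third dimension. An integer must evenly divide the size of the input's third dimension (input.size(2)); otherwise, a RuntimeError is raised.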
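The sketch below, assuming a recent PyTorch release, contrasts the two ways of calling torch.dsplit(): an integer that must evenly divide the size of the third dimension, and a list of split indices. The tensor t and the chosen indices are arbitrary illustrations.

```python
import torch

t = torch.arange(12).reshape(1, 2, 6)  # the third dimension has size 6

# Integer: 6 is divisible by 3, so we get three pieces of depth 2
even = torch.dsplit(t, 3)
print([p.shape[2] for p in even])  # [2, 2, 2]

# List of indices: cut before positions 1 and 4 along dim 2,
# giving the slices [0:1], [1:4] and [4:6]
uneven = torch.dsplit(t, [1, 4])
print([p.shape[2] for p in uneven])  # [1, 3, 2]

# 6 is not divisible by 4, so an integer of 4 fails
try:
    torch.dsplit(t, 4)
except RuntimeError as err:
    print("RuntimeError:", err)
```

The list form is the one to reach for when the pieces should have different depths.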
Example
The following example demonstrates the usage of the .dsplit() function:
import torch

# Create a 3D tensor
ten = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [8, 7, 6]]])

# Split the tensor into three sub-tensors
res = torch.dsplit(ten, 3)
print(res)
The above code produces the following output:
(tensor([[[1],[4]],[[7],[8]]]), tensor([[[2],[5]],[[8],[7]]]), tensor([[[3],[6]],[[9],[6]]]))
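As a follow-up sketch (assuming PyTorch 1.8 or later, where the Tensor.dsplit() method is available), the same split can be written as a method call, and torch.dstack() stitches the pieces back together. Because the sub-tensors are views of the input, an in-place change to one of them also changes the original tensor.

```python
import torch

ten = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [8, 7, 6]]])

# Method form, equivalent to torch.dsplit(ten, 3)
first, second, third = ten.dsplit(3)
print(first.shape)  # torch.Size([2, 2, 1])

# dstack concatenates along the third dimension, undoing the split
rebuilt = torch.dstack((first, second, third))
print(torch.equal(rebuilt, ten))  # True

# The pieces share memory with `ten`: zeroing `first` zeroes column 0 of `ten`
first.zero_()
print(ten[:, :, 0].tolist())  # [[0, 0], [0, 0]]
```

If independent copies are needed, calling .clone() on each piece avoids this shared-memory behavior.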