.dsplit()
Anonymous contributor
Published Nov 30, 2024
In PyTorch, the .dsplit() function splits a tensor with three or more dimensions into multiple sub-tensors depthwise, that is, along its third dimension (dim=2). It returns a tuple of these sub-tensors, each of which is a view of the input tensor.
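PyTorch's reference describes torch.dsplit() as equivalent to torch.tensor_split(input, indices_or_sections, dim=2), except that an integer argument must evenly divide the size of dim 2. The minimal sketch below, assuming a recent PyTorch release, checks that equivalence on a small tensor and shows that the returned pieces are views that share memory with the input. The tensor x is illustrative only.

```python
import torch

# A tensor of shape (2, 2, 4); its third dimension (dim=2) has size 4
x = torch.arange(16).reshape(2, 2, 4)

# Split depthwise into two equal sub-tensors along dim=2
parts = torch.dsplit(x, 2)
print([tuple(p.shape) for p in parts])  # [(2, 2, 2), (2, 2, 2)]

# With an integer that evenly divides dim 2, tensor_split along dim=2 gives the same pieces
same = torch.tensor_split(x, 2, dim=2)
print(all(torch.equal(a, b) for a, b in zip(parts, same)))  # True

# Each piece is a view of x, so writing into a piece also changes x
parts[0][0, 0, 0] = 100
print(x[0, 0, 0].item())  # 100
```

Because the pieces are views, no data is copied. Call .clone() on a piece if it needs to be modified independently of the original tensor.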
Syntax
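```python
torch.dsplit(input, indices_or_sections)
```

- input: The tensor to be split. It must have three or more dimensions.
- indices_or_sections: Either an integer giving the number of equal-sized sub-tensors to create, or a list of indices at which to split along the third dimension. An integer must evenly divide the size of the third dimension of input; otherwise, a runtime error is raised.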
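Besides an integer, the second argument (named indices_or_sections in PyTorch's own reference) also accepts a list of split positions along the third dimension. The sketch below, with an illustrative tensor x, splits at indices 1 and 4 and then shows the runtime error raised when an integer does not evenly divide the size of that dimension. The exact error message text may vary between PyTorch versions.

```python
import torch

# Shape (1, 2, 6): the third dimension (dim=2) has size 6
x = torch.arange(12).reshape(1, 2, 6)

# A list of indices splits dim 2 at positions 1 and 4: [0:1], [1:4], [4:6]
pieces = torch.dsplit(x, [1, 4])
print([p.shape[2] for p in pieces])  # [1, 3, 2]

# 4 does not evenly divide 6, so an integer argument of 4 is rejected
try:
    torch.dsplit(x, 4)
except RuntimeError as err:
    print("RuntimeError:", err)
```

With a list of indices, the pieces do not need to be the same size, so the divisibility requirement only applies to the integer form.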
Example
The following example demonstrates the usage of the .dsplit() function:
```python
import torch

# Create a 3D tensor of shape (2, 2, 3)
ten = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [8, 7, 6]]])

# Split the tensor into three sub-tensors along the third dimension
res = torch.dsplit(ten, 3)
print(res)
```
The above code produces the following output:
```
(tensor([[[1],
         [4]],

        [[7],
         [8]]]), tensor([[[2],
         [5]],

        [[8],
         [7]]]), tensor([[[3],
         [6]],

        [[9],
         [6]]]))
```
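Each of the three sub-tensors above has shape (2, 2, 1), because the third dimension of size 3 is divided into three pieces of size 1. As a sketch of the inverse operation, assuming the same tensor as in the example, torch.dstack() concatenates tensors along the third axis and so reassembles the original tensor from the pieces.

```python
import torch

ten = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [8, 7, 6]]])
res = torch.dsplit(ten, 3)

# Each of the three sub-tensors has shape (2, 2, 1)
print([tuple(r.shape) for r in res])  # [(2, 2, 1), (2, 2, 1), (2, 2, 1)]

# torch.dstack concatenates along the third axis, undoing the split
restored = torch.dstack(res)
print(torch.equal(restored, ten))  # True
```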