.narrow()
Published Nov 9, 2024
In PyTorch, the .narrow() method selects a subsection of a tensor along a specified dimension. It returns a view that shares storage with the original tensor instead of copying the data, so extracting a section of a large tensor does not allocate new memory for the elements.
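Because the result is a view, writing to it also changes the original tensor. The following is a minimal sketch of that behavior; the names base and view are illustrative and not part of the original example:

```python
import torch

# A 3x4 tensor of zeros to act as the original data
base = torch.zeros(3, 4)

# Narrow along dim=0 to rows 1 and 2; no data is copied
view = torch.narrow(base, 0, 1, 2)

# An in-place write to the view is visible through the original tensor
view.fill_(7.0)
print(base)

# The view starts at the same memory address as row 1 of the original
print(view.data_ptr() == base[1].data_ptr())
```

If an independent copy is needed, calling .clone() on the narrowed tensor produces one.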
Syntax
torch.narrow(input, dim, start, length)
- input: The tensor to be narrowed.
- dim: The dimension along which the input tensor is to be narrowed.
- start: The index where the narrowing begins. This can be a positive integer, a negative integer (to index from the end of dim), or a 0-dimensional integer tensor.
- length: The number of elements to include from the starting position.
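As a quick illustration of how these parameters map onto familiar operations, the sketch below compares the function form, the equivalent tensor method form, and ordinary slice indexing. The variable names are illustrative only:

```python
import torch

t = torch.arange(1, 13).reshape(3, 4)

# Function form: narrow dim=1, starting at index 1, keeping 2 elements
a = torch.narrow(t, 1, 1, 2)

# Method form: the tensor itself takes the place of the input argument
b = t.narrow(1, 1, 2)

# Slice form: start:start+length along the same dimension
c = t[:, 1:3]

print(torch.equal(a, b))  # compares function form with method form
print(torch.equal(a, c))  # compares function form with slicing

# A negative start counts from the end of the dimension
d = torch.narrow(t, 1, -3, 2)
print(torch.equal(d, t[:, -3:-1]))
```

Slicing is usually more readable when the dimension is known in advance, while .narrow() is convenient when the dimension is only known at run time and is passed in as a variable.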
Example
The following example illustrates the usage of the .narrow() method in various scenarios:
import torch

# Create a 2D tensor
tensor_2d = torch.arange(1, 13).reshape(3, 4)
print(f"Original 2D Tensor:\n{tensor_2d}")

# Case 1: Narrowing along rows (dim=0)
row_narrow = torch.narrow(tensor_2d, 0, 1, 2)
print("\nCase 1: Narrow Along Rows (dim=0, start=1, length=2)")
print(row_narrow)

# Case 2: Narrowing along columns (dim=1)
col_narrow = torch.narrow(tensor_2d, 1, 1, 2)
print("\nCase 2: Narrow Along Columns (dim=1, start=1, length=2)")
print(col_narrow)

# Case 3: Extracting a single column (dim=1, length=1)
single_col = torch.narrow(tensor_2d, 1, 2, 1)
print("\nCase 3: Extract Single Column (dim=1, start=2, length=1)")
print(single_col)

# Case 4: Narrow with length extending beyond tensor's dimension
# In this case, .narrow() raises an error because the sub-tensor's length exceeds the tensor's dimension
try:
    error_narrow = torch.narrow(tensor_2d, 0, 1, 5)
    print("\nCase 4: Narrow With Length Exceeding Dimension Size")
    print(error_narrow)
except RuntimeError as e:
    print("\nCase 4: RuntimeError -", e)

# Case 5: Using a negative start index (dim=1, start=-3, length=2)
negative_start_narrow = torch.narrow(tensor_2d, 1, -3, 2)
print("\nCase 5: Negative Start Index (dim=1, start=-3, length=2)")
print(negative_start_narrow)

# Case 6: Using a 0-dimensional start index
tensor_0_dim = torch.tensor(1)  # 0-dimensional integer tensor
tensor_start_narrow = torch.narrow(tensor_2d, 0, tensor_0_dim, 2)
print("\nCase 6: Start Index as a 0-Dim Tensor (dim=0, start=tensor(1), length=2)")
print(tensor_start_narrow)
The above program gives the following output:
Original 2D Tensor:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

Case 1: Narrow Along Rows (dim=0, start=1, length=2)
tensor([[ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

Case 2: Narrow Along Columns (dim=1, start=1, length=2)
tensor([[ 2,  3],
        [ 6,  7],
        [10, 11]])

Case 3: Extract Single Column (dim=1, start=2, length=1)
tensor([[ 3],
        [ 7],
        [11]])

Case 4: RuntimeError - start (1) + length (5) exceeds dimension size (3).

Case 5: Negative Start Index (dim=1, start=-3, length=2)
tensor([[ 2,  3],
        [ 6,  7],
        [10, 11]])

Case 6: Start Index as a 0-Dim Tensor (dim=0, start=tensor(1), length=2)
tensor([[ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
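Case 4 shows that .narrow() raises a RuntimeError when start + length runs past the end of the dimension. When that situation can occur with run-time values, one option is to clamp the length first. The helper below, clamped_narrow, is a hypothetical wrapper written for this illustration and is not part of PyTorch; it assumes an integer start:

```python
import torch

def clamped_narrow(t, dim, start, length):
    """Hypothetical helper: shorten length so start + length stays in bounds."""
    size = t.size(dim)
    # Resolve a negative start the same way indexing from the end would
    if start < 0:
        start += size
    if not 0 <= start < size:
        raise IndexError(f"start is out of range for dimension of size {size}")
    # Keep at most the elements that remain after start
    length = min(length, size - start)
    return torch.narrow(t, dim, start, length)

tensor_2d = torch.arange(1, 13).reshape(3, 4)

# Requests 5 rows starting at row 1, but only 2 rows remain
result = clamped_narrow(tensor_2d, 0, 1, 5)
print(result.shape)
print(result)
```

Silently shortening the result can hide bugs, so catching the RuntimeError as in Case 4 may be the better choice when an out-of-range request indicates a mistake.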