PyTorch .polygamma()

Priyanshjain10
Published Oct 31, 2025

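The .polygamma() function in PyTorch computes the polygamma function element-wise for each value in an input tensor. The polygamma function of order n is the n-th derivative of the digamma function (which is itself the first derivative of the natural logarithm of the gamma function), where n is a non-negative integer.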
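One way to see this definition in action is to differentiate torch.digamma() with autograd and compare the results against torch.polygamma(). The sketch below does this for n = 1 and n = 2; the sample points and the use of float64 (to keep round-off small) are choices made for illustration.

```python
import torch

# Sample points (arbitrary); float64 keeps round-off small
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, requires_grad=True)

# Each output depends only on its own input, so the gradient of the sum
# is the element-wise derivative. create_graph=True allows a second pass.
(d1,) = torch.autograd.grad(torch.digamma(x).sum(), x, create_graph=True)

# Second derivative of digamma
(d2,) = torch.autograd.grad(d1.sum(), x)

x_plain = x.detach()
print(torch.allclose(d1.detach(), torch.polygamma(1, x_plain)))  # n = 1
print(torch.allclose(d2, torch.polygamma(2, x_plain)))           # n = 2
```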

Syntax

torch.polygamma(n, input, *, out=None) → Tensor

Parameters:

  • n (int): The order of the polygamma function. When n=0, this is the digamma function; when n=1, this is the trigamma function.
  • input (Tensor): The input tensor containing values for which to compute the polygamma function.
  • out (Tensor, optional): The tensor in which to store the result. Defaults to None.
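A minimal sketch of the out parameter, writing the trigamma values (n=1) into a preallocated tensor. The input values are illustrative, and torch.empty_like() is used here so the output tensor matches the input's shape and dtype. The first value is π²/2 ≈ 4.9348.

```python
import torch

x = torch.tensor([0.5, 1.5, 2.5])

# Preallocate a tensor with the same shape and dtype as x
result = torch.empty_like(x)

# Compute trigamma (n=1) directly into `result`
torch.polygamma(1, x, out=result)
print(result)  # tensor([4.9348, 0.9348, 0.4904])
```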

Return value:

A tensor of the same shape as input, containing the polygamma values of order n.
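Because the function is applied element-wise, the output keeps the input's shape even for multi-dimensional tensors. The sketch below uses an illustrative 2-D tensor and also shows the equivalent Tensor.polygamma(n) method form.

```python
import torch

# A 2-D input (values chosen for illustration)
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

y = torch.polygamma(1, x)
print(y.shape)  # torch.Size([2, 3]), same as x

# Equivalent method form called on the tensor itself
print(torch.equal(y, x.polygamma(1)))  # True
```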

Example 1: Digamma Function (n=0)

In this example, .polygamma() is used with n=0 to compute the digamma function (first derivative of the log-gamma function):

import torch
# Create a tensor
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Compute digamma (polygamma with n=0)
digamma_values = torch.polygamma(0, x)
print(digamma_values)

The output of this code is:

tensor([-0.5772, 0.4228, 0.9228, 1.2561])
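These values can be checked by hand. Writing ψ(x) for the digamma function, ψ(1) = -γ, where γ ≈ 0.5772 is the Euler-Mascheroni constant, and ψ(x + 1) = ψ(x) + 1/x, so each later value adds 1, 1/2, and 1/3 in turn. A small sketch that checks both identities numerically (the hard-coded 10-digit value of γ is an assumption of the example):

```python
import torch

EULER_GAMMA = 0.5772156649  # Euler-Mascheroni constant, to 10 decimal places

x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
psi = torch.polygamma(0, x)

# psi(1) = -gamma
print(abs(psi[0].item() + EULER_GAMMA) < 1e-9)

# Recurrence: psi(x + 1) = psi(x) + 1/x
print(torch.allclose(torch.polygamma(0, x + 1), psi + 1 / x))
```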

Example 2: Trigamma Function (n=1)

In this example, .polygamma() is used with n=1 to compute the trigamma function (second derivative of the log-gamma function):

import torch
# Create input tensor
x = torch.tensor([1.0, 2.0, 3.0])
# Compute trigamma (polygamma with n=1)
trigamma_values = torch.polygamma(1, x)
print(trigamma_values)

The output of this code is:

tensor([1.6449, 0.6449, 0.3949])
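Similarly, writing ψ₁(x) for the trigamma function, ψ₁(1) = π²/6 ≈ 1.6449 and ψ₁(x + 1) = ψ₁(x) - 1/x², which accounts for the drop of 1 and then 1/4 between the printed values. A minimal numerical check, assuming float64 inputs:

```python
import math
import torch

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
trigamma = torch.polygamma(1, x)

# psi_1(1) = pi^2 / 6
print(math.isclose(trigamma[0].item(), math.pi ** 2 / 6, rel_tol=1e-7))

# Recurrence: psi_1(x + 1) = psi_1(x) - 1/x^2
print(torch.allclose(torch.polygamma(1, x + 1), trigamma - 1 / x ** 2))
```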

Example 3: Higher Order Polygamma

In this example, .polygamma() is used with n=2 to compute the second-order polygamma function (derivative of the trigamma function):

import torch
# Create input tensor
x = torch.tensor([2.0, 3.0, 4.0])
# Compute polygamma of order 2
polygamma_2 = torch.polygamma(2, x)
print(polygamma_2)

The output of this code is:

tensor([-0.4041, -0.1541, -0.0800])
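For n ≥ 1, the polygamma function can also be written in terms of the Hurwitz zeta function: ψ⁽ⁿ⁾(x) = (-1)ⁿ⁺¹ n! ζ(n + 1, x). The sketch below compares both sides for a few orders using torch.special.zeta(); the range of orders and the float64 inputs are illustrative choices.

```python
import math
import torch

x = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float64)

for n in range(1, 5):
    # Hurwitz zeta zeta(n + 1, x); torch.special.zeta takes the order s first, then q
    hurwitz = torch.special.zeta(torch.full_like(x, n + 1.0), x)
    expected = (-1) ** (n + 1) * math.factorial(n) * hurwitz
    print(n, torch.allclose(torch.polygamma(n, x), expected))
```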
