.bernoulli()
Anonymous contributor
Published Feb 21, 2025
In PyTorch, the .bernoulli() function generates a tensor of binary values (0s and 1s) based on the probabilities provided in an input tensor. Each element in the output tensor is drawn from a Bernoulli distribution, where the probability of drawing a 1 is given by the corresponding element in the input tensor. The output tensor has the same shape as the input tensor.
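As a rough illustration of the description above, the sketch below checks two consequences of it: probabilities of exactly 0 and 1 always yield 0 and 1, and the fraction of 1s over many draws lands close to the input probability. The sample size of 100,000 and the variable names are arbitrary choices for this example.

```python
import torch

# Edge cases: a probability of 0 always yields 0, a probability of 1 always yields 1
edges = torch.bernoulli(torch.tensor([0.0, 1.0]))
print(edges)

# Many independent draws that all use the same probability of 0.3
p = torch.full((100_000,), 0.3)
samples = torch.bernoulli(p)

# The fraction of 1s should be close to 0.3 (the exact value varies per run)
print(samples.mean())
```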
Syntax
torch.bernoulli(input, *, generator=None, out=None)
- input: A tensor containing probabilities (values between 0 and 1).
- generator (Optional): A pseudorandom number generator used for sampling. Defaults to None.
- out (Optional): The output tensor to store the result.
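The two optional parameters can be sketched as follows. This assumes the generator keyword is spelled `generator`, as in current PyTorch releases, and it uses a `torch.Generator` with an arbitrary seed of 0. The names `probs`, `gen`, and `buffer` are made up for this example.

```python
import torch

probs = torch.tensor([0.2, 0.5, 0.8])

# A dedicated generator with a fixed seed makes the draw repeatable
gen = torch.Generator().manual_seed(0)
first = torch.bernoulli(probs, generator=gen)

# Resetting the generator to the same seed reproduces the same sample
gen.manual_seed(0)
second = torch.bernoulli(probs, generator=gen)
print(torch.equal(first, second))  # True

# out= writes the sample into an existing tensor instead of allocating a new one
buffer = torch.empty_like(probs)
torch.bernoulli(probs, generator=gen, out=buffer)
print(buffer)
```

Passing a separate generator keeps the sampling independent of the global random state, which can be convenient when other parts of a program also draw random numbers.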
Example
The following example demonstrates the usage of the .bernoulli() function:
import torch

# Create a tensor containing probabilities
prob = torch.tensor([0.3, 0.6, 0.9, 0.4, 0.8, 0.5])

# Generate samples from the Bernoulli distribution
res = torch.bernoulli(prob)

# Print the resultant tensor
print(res)
The above code produces the following output:
tensor([0., 1., 1., 1., 1., 1.])
Note: Since the .bernoulli() function draws a fresh random sample for each element, the output may vary each time the code is run.
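If repeatable output is needed, one common approach is to seed PyTorch's global random number generator before sampling. The sketch below uses `torch.manual_seed()` with an arbitrary seed of 42. The sampled values themselves can still differ across PyTorch versions and devices.

```python
import torch

prob = torch.tensor([0.3, 0.6, 0.9, 0.4, 0.8, 0.5])

# Seed the global generator, then sample
torch.manual_seed(42)
first = torch.bernoulli(prob)

# Re-seeding with the same value repeats the same sequence of draws
torch.manual_seed(42)
second = torch.bernoulli(prob)

print(torch.equal(first, second))  # True
```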