Learn

In this lesson, you’ll learn what random forests are, how the random forest algorithm works, and how to implement it in scikit-learn.

We’ve seen that decision trees** can be powerful supervised machine learning models. However, they have weaknesses: in particular, decision trees are prone to overfitting. We’ve discussed some strategies to reduce this problem, like pruning, but sometimes that isn’t enough, and we need another way to help our trees generalize to new data. This is where the concept of a random forest comes in handy.

A random forest is an ensemble machine learning technique: it contains many decision trees that work together to classify new points. When a random forest is asked to classify a new point, it gives that point to each of its decision trees. Each tree reports its own classification, and the random forest returns the most popular one. It’s like every tree gets a vote, and the classification with the most votes wins. Some of the trees in the random forest may be overfit, but because the prediction is based on a large number of trees, any single overfit tree has less of an impact.
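The voting step described above can be sketched in a few lines of plain Python. This is a minimal illustration, assuming each "tree" is simply a callable that maps a point to a class label; the stub trees, the feature names, and the `forest_predict` helper are all hypothetical stand-ins, not trees learned from data and not scikit-learn's API. Note that this shows the simple majority ("hard") vote described in this lesson; scikit-learn's `RandomForestClassifier` instead combines its trees by averaging their predicted class probabilities.

```python
from collections import Counter


def forest_predict(trees, point):
    """Return the most common label among the trees' predictions for one point.

    `trees` is assumed to be a list of callables, each mapping a point to a
    class label (hypothetical stand-ins for fitted decision trees).
    """
    # Every tree casts one vote for the point.
    votes = [tree(point) for tree in trees]
    # most_common(1) returns [(label, count)]; on a tie, the label that was
    # encountered first among the votes is returned.
    label, _count = Counter(votes).most_common(1)[0]
    return label


# Hypothetical stub "trees" for illustration only; each one looks at a
# single survey answer stored in a dict.
def tree_1(p):
    return "A" if p["hours_studied"] >= 5 else "not A"


def tree_2(p):
    return "A" if p["average_grade"] >= 90 else "not A"


def tree_3(p):
    return "A" if p["hours_slept"] >= 7 else "not A"


# A made-up student, not real survey data.
student = {"hours_slept": 7, "hours_studied": 3, "average_grade": 92}
print(forest_predict([tree_1, tree_2, tree_3], student))  # prints "A"
```

Here `tree_1` votes "not A" while `tree_2` and `tree_3` vote "A", so the forest returns "A" by a 2:1 vote, even though one tree disagreed.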

** A prerequisite for this lesson is an understanding of decision trees. Be sure to revisit the decision trees module if needed!

Instructions

The illustration to the right depicts a random forest used to predict whether a student will get an A on a test, based on an anonymous survey. The survey fields (the features here!) include the hours the student slept, whether they expressed an intention to cheat, the hours they studied, and their average grade so far. Here, the random forest is made up of eleven decision trees: the prediction outcome of A is outvoted 6:5!
