Decision Trees
Gini Impurity

Consider the two trees below. Which tree would be more useful as a model that tries to predict whether someone would get an A in a class?

[Figure: top tree, whose leaf nodes contain more than one type of classification]
[Figure: bottom tree, whose leaf nodes each contain only one type of classification]

Let's say you use the top tree. You'll end up at a leaf node where the label is up for debate: the training data that reaches that leaf has labels from both classes! If you use the bottom tree, you'll end up at a leaf where there's only one type of label, so there's no debate at all. We'd be much more confident in our classification if we used the bottom tree.

This idea can be quantified by calculating the Gini impurity of a set of data points. To find the Gini impurity, start at 1 and subtract the square of the proportion (the fraction of the set) that has each label. For example, if a data set had three items of class A and one item of class B, the Gini impurity of the set would be

$$1 - \bigg(\frac{3}{4}\bigg)^2 - \bigg(\frac{1}{4}\bigg)^2 = 0.375$$
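To double-check the arithmetic, the same calculation can be written directly in Python. This is only the formula above with the fractions 3/4 and 1/4 plugged in; the variable names p_a and p_b are illustrative.

```python
# Proportions for a set with three items of class A and one item of class B
p_a = 3 / 4
p_b = 1 / 4

# Start at 1 and subtract the square of each label's proportion
impurity = 1 - p_a ** 2 - p_b ** 2
print(impurity)  # 0.375
```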

If every item in a data set has the same label, you'd end up with a Gini impurity of 0. Lower impurity at the leaves means more confident classifications, so the lower the impurity, the better the decision tree!
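More generally, if a set contains k distinct labels and p_i is the fraction of items carrying label i, the recipe above can be written as follows (the symbols G, k, and p_i are introduced here for illustration):

$$G = 1 - \sum_{i=1}^{k} p_i^2$$

When every item has the same label, that label's fraction is 1, so G = 1 - 1^2 = 0. With two labels split evenly, G = 1 - (1/2)^2 - (1/2)^2 = 0.5.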

Instructions

1.

Let’s find the Gini impurity of the set of labels we’ve given you.

Let's start by creating a variable named impurity and setting it to 1.

2.

We now want to count how many times each unique label appears in the dataset. Python's Counter class, from the collections module, can do this quickly.

For example, running the following code:

from collections import Counter

lst = ["A", "A", "B"]
counts = Counter(lst)

would result in counts storing this object:

Counter({"A": 2, "B": 1})

Create a Counter object from the items in labels and name it label_counts.

Print out label_counts to see if it matches what you expect.
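A minimal sketch of this step, assuming a hypothetical labels list with three items of class A and one of class B (the exercise supplies its own list, which isn't shown here):

```python
from collections import Counter

# Hypothetical labels; the exercise provides its own list
labels = ["A", "A", "A", "B"]

# Count how many times each unique label appears
label_counts = Counter(labels)
print(label_counts)  # Counter({'A': 3, 'B': 1})
```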

3.

Let’s find the probability of each label given the dataset. Loop through each label in label_counts.

Inside the for loop, create a variable named probability_of_label. Set it equal to the label count divided by the total number of labels in the dataset.

For every label, the count associated with that label can be found at label_counts[label].

We can find the total number of labels in the dataset with len(labels).
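Here is one way this loop might look, again assuming the hypothetical labels list from the 3-to-1 example. The probabilities dictionary is an illustrative extra (not part of the exercise) that collects each label's probability so you can inspect them together.

```python
from collections import Counter

# Hypothetical labels (three of class A, one of class B)
labels = ["A", "A", "A", "B"]
label_counts = Counter(labels)

probabilities = {}  # hypothetical helper: label -> probability
for label in label_counts:
    # Fraction of the dataset that carries this label
    probability_of_label = label_counts[label] / len(labels)
    probabilities[label] = probability_of_label
    print(label, probability_of_label)
```

Because every item has exactly one label, these probabilities always add up to 1 (here 0.75 and 0.25).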

4.

We now want to take probability_of_label, square it, and subtract it from impurity.

Inside the for loop, subtract probability_of_label squared from impurity.

In Python, you can square x by using x ** 2.
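Putting the steps together, a sketch of the full calculation might look like the following, assuming the same hypothetical labels list. The result should match the 0.375 worked out by hand at the top of the lesson.

```python
from collections import Counter

# Hypothetical labels (three of class A, one of class B)
labels = ["A", "A", "A", "B"]

impurity = 1
label_counts = Counter(labels)
for label in label_counts:
    probability_of_label = label_counts[label] / len(labels)
    # Subtract this label's squared probability from the running impurity
    impurity -= probability_of_label ** 2

# Print once, after the loop has processed every label
print(impurity)  # 0.375
```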

5.

Outside of the for loop, print impurity.

Test out some of the other labels that we’ve given you by uncommenting them. Which one do you expect to have the lowest impurity?

In the next exercise, we’ll put all of your code into a function. If you want a challenge, try creating the function yourself! Ours is named gini(), takes labels as a parameter, and returns impurity.
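For the challenge, here is one possible version of gini(), a sketch that wraps the steps above in a function; the exercise's own solution may differ. The label sets passed to it below are hypothetical.

```python
from collections import Counter

def gini(labels):
    """Return the Gini impurity of a list of labels."""
    impurity = 1
    label_counts = Counter(labels)
    for label in label_counts:
        probability_of_label = label_counts[label] / len(labels)
        impurity -= probability_of_label ** 2
    return impurity

# Hypothetical label sets to compare
print(gini(["A", "A", "A", "A"]))  # 0.0
print(gini(["A", "A", "A", "B"]))  # 0.375
print(gini(["A", "A", "B", "B"]))  # 0.5
```

As expected, the set containing a single label has the lowest impurity, and the even two-way split has the highest of the three.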
