Consider the two trees below. Which tree would be more useful as a model that tries to predict whether someone would get an A in a class?
Let’s say you use the top tree. You’ll end up at a leaf node where the label is up for debate. The training data has labels from both classes! If you use the bottom tree, you’ll end up at a leaf where there’s only one type of label. There’s no debate at all! We’d be much more confident about our classification if we used the bottom tree.
This idea can be quantified by calculating the Gini impurity of a set of data points. To find the Gini impurity, start at 1 and subtract the squared proportion of each label in the set. For example, if a data set had three items of class A and one item of class B, the Gini impurity of the set would be
$$1 - \left(\frac{3}{4}\right)^2 - \left(\frac{1}{4}\right)^2 = 0.375$$
If a data set has only one class, you’d end up with a Gini impurity of 0. The lower the impurity, the better the decision tree!
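The arithmetic for the three-A, one-B example can be checked directly in Python. This is only a sketch of the calculation; the variable names mixed_impurity and pure_impurity are made up for illustration and are not part of the exercise.

```python
# Hypothetical set: three items of class A and one item of class B.
# Start at 1 and subtract the squared proportion of each label.
mixed_impurity = 1 - (3 / 4) ** 2 - (1 / 4) ** 2
print(mixed_impurity)  # 0.375

# Hypothetical set with a single class: four items of class A.
pure_impurity = 1 - (4 / 4) ** 2
print(pure_impurity)  # 0.0
```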
Instructions
Let’s find the Gini impurity of the set of labels we’ve given you.
Let’s start by creating a variable named impurity and setting it to 1.
We now want to count how many times each unique label appears in the dataset. Python’s Counter object can do this quickly.
For example, the following code:
from collections import Counter
lst = ["A", "A", "B"]
counts = Counter(lst)
would result in counts storing this object:
Counter({"A": 2, "B": 1})
Create a Counter object from the items in labels and name it label_counts.
Print out label_counts to see if it matches what you expect.
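As a rough sketch of this step, the snippet below builds label_counts from a small stand-in list. The contents of labels here are an assumption for illustration; the list provided in the exercise will differ.

```python
from collections import Counter

# Hypothetical labels, standing in for the list provided in the exercise.
labels = ["A", "A", "A", "B"]

# Counter maps each unique label to the number of times it appears.
label_counts = Counter(labels)
print(label_counts)  # Counter({'A': 3, 'B': 1})
```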
Let’s find the probability of each label given the dataset. Loop through each label in label_counts.
Inside the for loop, create a variable named probability_of_label. Set it equal to the count of that label divided by the total number of labels in the dataset.
For every label, the count associated with that label can be found at label_counts[label].
We can find the total number of labels in the dataset with len(labels).
We now want to take probability_of_label, square it, and subtract it from impurity.
Inside the for loop, subtract probability_of_label squared from impurity.
In Python, you can square x by using x ** 2.
Outside of the for loop, print impurity.
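Put together, the steps so far might look like the sketch below. The labels list is a hypothetical stand-in for the one provided in the exercise, so the printed value applies only to this example.

```python
from collections import Counter

# Hypothetical labels, standing in for the list provided in the exercise.
labels = ["A", "A", "A", "B"]

impurity = 1
label_counts = Counter(labels)

for label in label_counts:
    # Proportion of the dataset that carries this label.
    probability_of_label = label_counts[label] / len(labels)
    # Subtract the squared proportion from the running impurity.
    impurity -= probability_of_label ** 2

print(impurity)  # 0.375
```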
Test out some of the other labels lists that we’ve given you by uncommenting them. Which one do you expect to have the lowest impurity?
In the next exercise, we’ll put all of your code into a function. If you want a challenge, try creating the function yourself! Ours is named gini(), takes labels as a parameter, and returns impurity.
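If you would like to compare your attempt against something, here is one possible sketch of such a function. It follows the name, parameter, and return value described above, but the body and the sample lists passed to it are assumptions, not necessarily the version used in the next exercise.

```python
from collections import Counter

def gini(labels):
    """Return the Gini impurity of a list of labels (one possible sketch)."""
    impurity = 1
    label_counts = Counter(labels)
    for label in label_counts:
        probability_of_label = label_counts[label] / len(labels)
        impurity -= probability_of_label ** 2
    return impurity

# Hypothetical label sets for comparison.
print(gini(["A", "A", "A", "A"]))  # 0.0
print(gini(["A", "A", "A", "B"]))  # 0.375
print(gini(["A", "B", "A", "B"]))  # 0.5
```

The set with only one class has the lowest impurity, and the evenly split set has the highest of the three.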