Now that we can find the best feature to split the dataset, we can repeat this process again and again to create the full tree. This is a recursive algorithm! We start with every data point from the training set, find the best feature to split the data, split the data based on that feature, and then recursively repeat the process on each subset created by the split.
We’ll stop the recursion when we can no longer find a feature that results in any information gain. In other words, we want to create a leaf of the tree when we can’t find a way to split the data that makes purer subsets.
The leaf should keep track of the classes of the data points from the training set that ended up in the leaf. In our implementation, we’ll use a Counter object to keep track of the counts of labels. We’ll use these counts to make predictions about new data that we give the tree.
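As a rough illustration of what such a leaf might look like, here is a minimal sketch. The label values are made up for the example, and predicting with the most common label is one simple assumption about how the counts could be used, not necessarily how the lesson's own prediction step works.

```python
from collections import Counter

# Hypothetical labels of the training points that ended up in one leaf.
leaf_labels = ["acc", "unacc", "unacc", "good", "unacc"]

# The leaf stores how many times each class appeared.
leaf = Counter(leaf_labels)

# One simple way to predict (an assumption for this sketch):
# pick the class with the highest count in the leaf.
prediction = leaf.most_common(1)[0][0]

print(leaf)
print(prediction)
```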
We’ve given you the function find_best_split(), which takes a set of data points and a set of labels. The function returns the index of the feature that causes the best split and the information gain caused by that split.
For now, at the bottom of your code, call this function using the car data points and car_labels as parameters, and store the two returned values (the feature index and the information gain) in two variables.
Print those two variables. What was the best feature to split on and what was the information gain?
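Let’s create a function called build_tree() that takes the data points and labels as parameters.
Move your call of find_best_split() inside this function, but change the parameters from the car data and car_labels to the function’s own data and labels parameters.
If the information gain returned by find_best_split() is 0, return a Counter object of labels. We’ve reached the base case: there’s no way to gain any more information, so we want to create a leaf.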
After the if statement, we want to start working on the recursive case.
In the recursive case, we want to split the data into subsets using the best feature, and then recursively call the build_tree() function on those subsets to create subtrees. Finally, we want to return a list of all those subtrees.
Let’s begin by splitting the data. You can do this by using the split() function, which takes three parameters: the data and labels that you want to split, and the index of the feature you want to split on.
Store the result of the split() function in two variables named data_subsets and label_subsets.
For now, return data_subsets at the bottom of your function.
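The split() function is provided by the lesson and its body is not shown here. To make the step concrete, the following is a hypothetical sketch of what a function with that signature could do, assuming each data point is a list of categorical feature values and that one subset is created per distinct value of the chosen feature.

```python
def split(data, labels, feature_index):
    """Group data points and labels by the value of one feature.

    This is an illustrative stand-in, not the lesson's own split().
    """
    data_subsets = {}
    label_subsets = {}
    for point, label in zip(data, labels):
        value = point[feature_index]
        # Both dicts are filled in the same order, so the resulting
        # lists stay aligned: subset i of the data matches subset i
        # of the labels.
        data_subsets.setdefault(value, []).append(point)
        label_subsets.setdefault(value, []).append(label)
    return list(data_subsets.values()), list(label_subsets.values())


# Made-up example data: two features per point.
example_data = [["low", "2"], ["high", "4"], ["low", "4"]]
example_labels = ["acc", "unacc", "good"]

data_subsets, label_subsets = split(example_data, example_labels, 0)
print(data_subsets)
print(label_subsets)
```

Because the two returned lists are aligned, data_subsets[i] and label_subsets[i] always describe the same group of points.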
Before that final return statement, create an empty list named branches. This list will store all of the subtrees we’re about to make from our recursive calls.
We now want to loop through all of the subsets of data and labels. Set up your for loop like this:
for i in range(len(data_subsets)):
Inside the for loop, call build_tree() using data_subsets[i] and label_subsets[i] as parameters and append the result to branches.
Finally, outside the for loop, return branches instead of data_subsets.
Let’s test our function! At the bottom of your code, outside of your function definition, call build_tree() using the car data and car_labels as parameters and store the result in a variable named tree.
We’ve written a function called print_tree() that will help you visualize the tree. Call print_tree() using tree as a parameter.
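Putting the steps together, the whole exercise might look roughly like the sketch below. The lesson supplies its own find_best_split() and split(), so the versions here are hypothetical stand-ins: they assume categorical features, one subset per feature value, and information gain measured as the drop in Gini impurity. The tiny dataset is invented for the example.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of labels (0 means perfectly pure)."""
    total = len(labels)
    return 1 - sum((count / total) ** 2 for count in Counter(labels).values())


def information_gain(labels, label_subsets):
    """Impurity of the parent minus the weighted impurity of the subsets."""
    gain = gini(labels)
    for subset in label_subsets:
        gain -= gini(subset) * len(subset) / len(labels)
    return gain


def split(data, labels, feature_index):
    """Stand-in for the lesson's split(): one subset per feature value."""
    data_subsets, label_subsets = {}, {}
    for point, label in zip(data, labels):
        value = point[feature_index]
        data_subsets.setdefault(value, []).append(point)
        label_subsets.setdefault(value, []).append(label)
    return list(data_subsets.values()), list(label_subsets.values())


def find_best_split(data, labels):
    """Stand-in for the lesson's find_best_split()."""
    best_feature, best_gain = 0, 0
    for feature in range(len(data[0])):
        _, label_subsets = split(data, labels, feature)
        gain = information_gain(labels, label_subsets)
        if gain > best_gain:
            best_feature, best_gain = feature, gain
    return best_feature, best_gain


def build_tree(data, labels):
    best_feature, best_gain = find_best_split(data, labels)
    # Base case: no split improves purity, so this node becomes a leaf.
    if best_gain == 0:
        return Counter(labels)
    # Recursive case: split on the best feature and build one subtree
    # per subset.
    data_subsets, label_subsets = split(data, labels, best_feature)
    branches = []
    for i in range(len(data_subsets)):
        branches.append(build_tree(data_subsets[i], label_subsets[i]))
    return branches


# Invented toy data: the first feature fully determines the label.
toy_data = [["low", "2"], ["low", "4"], ["high", "2"], ["high", "4"]]
toy_labels = ["acc", "acc", "unacc", "unacc"]

tree = build_tree(toy_data, toy_labels)
print(tree)
```

In this toy case the first feature separates the labels perfectly, so the tree is a list of two leaves, each a Counter holding a single class.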