We can finally use our tree as a classifier! Given a new data point, we start at the top of the tree and follow the path of the tree until we hit a leaf. Once we get to a leaf, we’ll use the classes of the points from the training set to make a classification.
We’ve slightly changed the way our build_tree() function works. Instead of returning a list of branches or a Counter object, the build_tree() function now returns a Leaf object or an Internal_Node object. We’ll explain how to use these objects in the instructions!
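The lesson does not show how these two classes are defined, so the following is only a minimal sketch of what they might look like. The attribute names labels, feature, branches, and value are the ones used later in this exercise; the constructor signatures and the sample labels are assumptions made for illustration.

```python
from collections import Counter


class Leaf:
    """Hypothetical leaf: holds the label counts of the training points that ended up here."""

    def __init__(self, labels, value):
        self.labels = Counter(labels)  # maps each label to how often it occurs
        self.value = value             # feature value on the branch leading to this node


class Internal_Node:
    """Hypothetical internal node: remembers which feature index was used to split."""

    def __init__(self, feature, branches, value):
        self.feature = feature    # index of the feature this node splits on
        self.branches = branches  # child nodes (Leaf or Internal_Node objects)
        self.value = value        # feature value on the branch leading to this node


# A tiny hand-built tree that splits on feature index 0 (made-up data).
leaf_a = Leaf(["acc", "acc", "unacc"], "med")
leaf_b = Leaf(["unacc"], "low")
root = Internal_Node(0, [leaf_a, leaf_b], None)

print(root.feature)             # index used for the split
print(root.branches[0].labels)  # label counts stored in the first leaf
```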
Let’s write a function that will use our tree to classify new points!
We’ve created a tree named tree using a lot of car data. Use the print_tree() function with tree as a parameter to see it.
Notice that the tree now knows which feature was used to split the data. This new information is contained in the Internal_Node objects. This will come in handy when we write our classify function!
Comment out printing the tree once you get a sense of how large it is!
Let’s start writing the classify() function. classify() should take a datapoint and a tree as parameters.
The first thing classify should do is check to see if we’re at a leaf. Check to see if tree is a Leaf by using the isinstance() function. For example, isinstance(a, list) will be True if a is a list. You should check if tree is a Leaf.
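As a quick illustration of the check described above, here is a small sketch of isinstance() in action. The empty Leaf class is a hypothetical stand-in for the lesson's real class, which is not shown here.

```python
class Leaf:
    """Hypothetical stand-in for the lesson's Leaf class."""
    pass


a = [1, 2, 3]
node = Leaf()

print(isinstance(a, list))     # True: a is a list
print(isinstance(node, Leaf))  # True: node is a Leaf
print(isinstance(a, Leaf))     # False: a list is not a Leaf
```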
If we’ve found a Leaf, that means we want to return the label with the highest count. The label counts are stored in tree.labels.
You could find the label with the largest count by using a for loop, or by using this rather complicated line of code:
return max(tree.labels.items(), key=operator.itemgetter(1))[0]
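To make the two options above concrete, here is a sketch that finds the most common label both ways. The Counter contents are made-up sample data. Note that max() over items() returns a (label, count) pair, so the trailing [0] is what picks out the label itself.

```python
from collections import Counter
import operator

# Made-up label counts, standing in for what a leaf's labels might hold.
labels = Counter({"unacc": 5, "acc": 2, "good": 1})

# Option 1: a plain for loop that tracks the best label seen so far.
best_label = None
best_count = -1
for label, count in labels.items():
    if count > best_count:
        best_label = label
        best_count = count

# Option 2: the one-liner. itemgetter(1) compares pairs by their count,
# and [0] extracts the label from the winning (label, count) pair.
one_liner = max(labels.items(), key=operator.itemgetter(1))[0]

print(best_label)  # unacc
print(one_liner)   # unacc
```

If the counts really are stored in a Counter, labels.most_common(1)[0][0] is another way to get the same label.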
If we’re not at a leaf, we want to find the branch that corresponds to our data point. For example, if we’re splitting on index 0 and our data point is ['med', 'low', '4', '2', 'big', 'low'], we want to find the branch that contains all of the points with med at index 0.
To start, let’s find datapoint’s value of the feature we’re looking for. If datapoint were the example above, and the feature we’re interested in is 0, this would be med.
Outside the if statement, create a variable named value and set it equal to datapoint[tree.feature]. tree.feature contains the index of the feature that we’re splitting on, so datapoint[tree.feature] is the value at that index.
To help us check your code, return value.
Start by deleting return value. Let’s now loop through all of the branches in the tree to find the one that has all the data points with value at the correct index.
Your loop should look like this:
for branch in tree.branches:
Next, inside the loop, check to see if branch.value is equal to value. If it is, we’ve found the branch that we’re looking for! We want to now recursively call classify() on that branch:
return classify(datapoint, branch)
We know that one of these branches will be the one we’re looking for, so we know that this return statement will happen once.
Finally, outside of your function, call classify() with a new data point and tree as parameters. Print the results. You should see a classification for this new point.
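Putting the steps above together, the finished function might look roughly like the sketch below. The Leaf and Internal_Node definitions, the hand-built tiny_tree, and the name test_point are all assumptions made so the example can run on its own; the lesson's real tree is built from the car data by build_tree().

```python
from collections import Counter
import operator


class Leaf:
    """Hypothetical leaf node: label counts plus the branch value leading here."""
    def __init__(self, labels, value):
        self.labels = labels
        self.value = value


class Internal_Node:
    """Hypothetical internal node: split feature index, children, branch value."""
    def __init__(self, feature, branches, value):
        self.feature = feature
        self.branches = branches
        self.value = value


def classify(datapoint, tree):
    # Base case: at a leaf, return the label with the highest count.
    if isinstance(tree, Leaf):
        return max(tree.labels.items(), key=operator.itemgetter(1))[0]
    # Otherwise, look up the datapoint's value for the feature this node splits on.
    value = datapoint[tree.feature]
    # Follow the branch whose value matches and classify recursively.
    for branch in tree.branches:
        if branch.value == value:
            return classify(datapoint, branch)
    # If no branch matches (a value never seen in training), this sketch returns None.
    return None


# A tiny hand-built tree: split on index 0, then on index 4 for 'med' points.
low = Leaf(Counter({"unacc": 3}), "low")
med_small = Leaf(Counter({"unacc": 2, "acc": 1}), "small")
med_big = Leaf(Counter({"acc": 4, "unacc": 1}), "big")
med = Internal_Node(4, [med_small, med_big], "med")
tiny_tree = Internal_Node(0, [low, med], None)

test_point = ['med', 'low', '4', '2', 'big', 'low']
print(classify(test_point, tiny_tree))  # acc
```

Here the point follows the med branch at index 0, then the big branch at index 4, and lands on a leaf where acc is the most common label.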