We can finally use our tree as a classifier! Given a new data point, we start at the top of the tree and follow the branches that match the point's feature values until we hit a leaf. Once we get to a leaf, we’ll use the classes of the training points that ended up in that leaf to make a classification.
We’ve slightly changed the way our build_tree() function works. Instead of returning a list of branches or a Counter object, build_tree() now returns a Leaf object or an Internal_Node object. We’ll explain how to use these objects in the instructions!
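The instructions never show these two classes, but they rely on a few attributes: tree.labels on a Leaf, tree.feature and tree.branches on an Internal_Node, and branch.value on every child. Here's a minimal sketch of classes with just those attributes. The constructor signatures, the use of Counter for the labels, and the tiny example tree are assumptions for illustration, not the exercise's actual code.

```python
from collections import Counter


class Leaf:
    """Hypothetical sketch of a leaf: label counts for the training points that reached it."""

    def __init__(self, labels, value=None):
        self.labels = Counter(labels)  # label -> count
        self.value = value             # feature value on the edge leading into this node


class Internal_Node:
    """Hypothetical sketch of an internal node: remembers which feature it splits on."""

    def __init__(self, feature, branches, value=None):
        self.feature = feature    # index of the feature this node splits on
        self.branches = branches  # child nodes (Leaf or Internal_Node objects)
        self.value = value        # feature value on the edge leading into this node


# A one-split tree: the root splits on index 0, and each child is reached by one value.
tree = Internal_Node(
    feature=0,
    branches=[
        Leaf(["acc", "acc", "unacc"], value="med"),
        Leaf(["unacc", "unacc"], value="high"),
    ],
)
print(tree.feature, [branch.value for branch in tree.branches])
```

Storing value on each child (rather than, say, as dictionary keys on the parent) matches how the instructions later compare branch.value with the data point's value.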
Let’s write a function that will use our tree to classify new points!
Instructions
We’ve created a tree named tree using a lot of car data. Call the print_tree() function with tree as an argument to see it.
Notice that the tree now knows which feature was used to split the data. This new information is contained in the Leaf and Internal_Node classes. This will come in handy when we write our classify function!
Comment out the line that prints the tree once you get a sense of how large it is!
Let’s start writing the classify() function. classify() should take a datapoint and a tree as parameters.
The first thing classify() should do is check whether we’re at a leaf. Use the isinstance() function to check whether tree is a Leaf. For example, isinstance(a, list) will be True if a is a list.
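To see how isinstance() tells the two node types apart, here's a small sketch using hypothetical stand-in Leaf and Internal_Node classes (the exercise defines its own versions):

```python
class Leaf:
    """Hypothetical stand-in for the exercise's Leaf class."""

    def __init__(self, labels):
        self.labels = labels


class Internal_Node:
    """Hypothetical stand-in for the exercise's Internal_Node class."""

    def __init__(self, feature, branches):
        self.feature = feature
        self.branches = branches


a = [1, 2, 3]
print(isinstance(a, list))     # True: a is a list
print(isinstance(a, dict))     # False: a is not a dict

leaf = Leaf({"acc": 2})
node = Internal_Node(0, [leaf])
print(isinstance(leaf, Leaf))  # True
print(isinstance(node, Leaf))  # False: an Internal_Node is not a Leaf
```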
If we’ve found a Leaf, we want to return the label with the highest count. The label counts are stored in tree.labels, which maps each label to its count.
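You could find the label with the largest count by using a for loop, or by using this rather complicated line of code (it needs import operator at the top of your file):
return max(tree.labels.items(), key=operator.itemgetter(1))[0]
Here, tree.labels.items() produces (label, count) pairs, key=operator.itemgetter(1) makes max() compare those pairs by their counts, and the final [0] pulls the label out of the winning pair.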
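For comparison, here's a sketch of the for loop version next to the one-liner, wrapped in two hypothetical helper functions (majority_label_loop and majority_label_max) and run on made-up label counts:

```python
import operator


def majority_label_loop(labels):
    """Return the label with the largest count by scanning every (label, count) pair."""
    best_label, best_count = None, -1
    for label, count in labels.items():
        if count > best_count:  # strictly larger, so earlier labels win ties
            best_label, best_count = label, count
    return best_label


def majority_label_max(labels):
    """Same result using max(): compare pairs by count, then take the label."""
    return max(labels.items(), key=operator.itemgetter(1))[0]


# Made-up label counts at a leaf (label -> count).
counts = {"unacc": 5, "acc": 2, "good": 1}
print(majority_label_loop(counts), majority_label_max(counts))  # unacc unacc
```

Both versions break ties the same way: the loop only replaces the current best on a strictly larger count, and max() returns the first maximal item it encounters, so the tied label that comes first in iteration order wins.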
If we’re not at a leaf, we want to find the branch that corresponds to our data point. For example, if we’re splitting on index 0 and our data point is ['med', 'low', '4', '2', 'big', 'low'], we want to find the branch that contains all of the points with med at index 0.
To start, let’s find datapoint’s value for the feature we’re splitting on. If datapoint were the example above and the feature we’re interested in is 0, this would be med.
Outside the if statement, create a variable named value and set it equal to datapoint[tree.feature]. tree.feature contains the index of the feature that we’re splitting on, so datapoint[tree.feature] is the value at that index.
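A quick check of this indexing on the example point above. The namedtuple is a hypothetical stand-in for an Internal_Node that carries only the feature attribute:

```python
from collections import namedtuple

# Hypothetical stand-in for an Internal_Node: only the attribute used here.
Node = namedtuple("Node", ["feature"])

datapoint = ['med', 'low', '4', '2', 'big', 'low']
tree = Node(feature=0)  # this node splits on index 0

value = datapoint[tree.feature]
print(value)  # med
```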
To help us check your code, return value.
Start by deleting return value.
Let’s now loop through all of the branches in the tree to find the one that contains all the data points with value at the correct index.
Your loop should look like this:
for branch in tree.branches:
Next, inside the loop, check whether branch.value is equal to value. If it is, we’ve found the branch that we’re looking for! Now we want to recursively call classify() on that branch:
return classify(datapoint, branch)
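As long as datapoint’s value for this feature appeared among the training points that reached this node, exactly one branch will match, so this return statement runs once at each internal node along the path. If the value never appeared there, no branch matches, the loop ends, and classify() returns None.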
Finally, outside of your function, call classify() with test_point and tree as arguments. Print the result. You should see a classification for this new point.
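Putting the steps together, here's a sketch of the finished function. The Leaf and Internal_Node classes are minimal hypothetical stand-ins (the exercise provides its own), and the small hand-built tree and test_point are made up for illustration; this is not the exercise's car-data tree. The final return None is not in the instructions: Python would return None there implicitly anyway, and writing it out just makes the "no matching branch" case visible.

```python
import operator
from collections import Counter


class Leaf:
    """Hypothetical stand-in: label counts plus the branch value leading here."""

    def __init__(self, labels, value=None):
        self.labels = Counter(labels)
        self.value = value


class Internal_Node:
    """Hypothetical stand-in: split feature, children, and the branch value leading here."""

    def __init__(self, feature, branches, value=None):
        self.feature = feature
        self.branches = branches
        self.value = value


def classify(datapoint, tree):
    # Base case: at a leaf, return the label with the highest count.
    if isinstance(tree, Leaf):
        return max(tree.labels.items(), key=operator.itemgetter(1))[0]
    # Otherwise, look up the point's value for the feature this node splits on...
    value = datapoint[tree.feature]
    # ...and recurse into the branch whose value matches it.
    for branch in tree.branches:
        if branch.value == value:
            return classify(datapoint, branch)
    # No branch matched: this value never appeared in the training data at this node.
    return None


# Made-up two-level tree: split on index 0, then (for 'med') on index 5.
tree = Internal_Node(0, [
    Internal_Node(5, [
        Leaf(["unacc", "unacc", "acc"], value="low"),
        Leaf(["acc", "acc", "good"], value="high"),
    ], value="med"),
    Leaf(["unacc", "unacc"], value="vhigh"),
])

test_point = ['med', 'low', '4', '2', 'big', 'high']
print(classify(test_point, tree))  # acc
```

Tracing test_point: the root splits on index 0, so 'med' sends it to the inner node; that node splits on index 5, so 'high' sends it to the leaf whose most common label is acc.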