Now that we have an understanding of how decision trees are created and used, let’s talk about some of their limitations.
One problem with the way we’re currently making our decision trees is that our trees aren’t always globally optimal. This means that there might be a better tree out there somewhere that produces better results. But wait, why did we go through all that work of finding information gain if it’s not producing the best possible tree?
Our current strategy of creating trees is greedy. We assume that the best way to create a tree is to find the feature that will result in the largest information gain right now and split on that feature. We never consider the ramifications of that split further down the tree. It’s possible that if we split on a suboptimal feature right now, we would find even better splits later on. Unfortunately, finding a globally optimal tree is an extremely difficult task, and finding a tree using our greedy approach is a reasonable substitute.
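To see how the greedy choice can miss a better tree, here is a minimal sketch in plain Python. It assumes categorical features stored in tuples and measures impurity with entropy; the helper names (entropy, information_gain, best_split) and the toy data are hypothetical, and you could swap in another impurity measure such as Gini impurity without changing the greedy logic. In the toy data the label is x XOR y, so splitting on x and then on y classifies every row perfectly. Yet neither x nor y gains anything on its own, so the greedy step picks the weakly correlated feature z instead.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def information_gain(rows, labels, feature):
    """Entropy reduction from splitting on one categorical feature."""
    # Group the labels by the value each row has for this feature.
    groups = {}
    for row, label in zip(rows, labels):
        groups.setdefault(row[feature], []).append(label)
    weighted = sum(len(g) / len(labels) * entropy(g) for g in groups.values())
    return entropy(labels) - weighted


def best_split(rows, labels):
    """Greedy step: pick the feature with the largest gain right now."""
    features = range(len(rows[0]))
    return max(features, key=lambda f: information_gain(rows, labels, f))


# Toy data (hypothetical): each row is (x, y, z) and the label is x XOR y.
rows = [(0, 0, 0), (0, 0, 0), (0, 1, 1), (0, 1, 1),
        (1, 0, 0), (1, 0, 1), (1, 1, 1), (1, 1, 0)]
labels = [0, 0, 1, 1, 1, 1, 0, 0]

for f, name in enumerate("xyz"):
    print(name, round(information_gain(rows, labels, f), 3))
# x 0.0
# y 0.0
# z 0.189
print("greedy split:", "xyz"[best_split(rows, labels)])
# greedy split: z
```

Because the greedy step only looks one split ahead, it never discovers that a split with zero gain now (on x) would enable a perfect split one level later (on y).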
Another problem with our trees is that they can overfit the data. This means that the structure of the tree depends too heavily on the training data and doesn’t accurately represent what data in the real world looks like. In general, larger trees tend to overfit more: as the tree grows, it becomes more closely tuned to the training data and loses its more general understanding of real-world data.
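A quick way to see overfitting is to grow trees of increasing depth on the same train/test split and compare training and test accuracy. This sketch uses a synthetic dataset from scikit-learn's make_classification with some label noise (flip_y) rather than the lesson's data; the depths tried and the random seeds are arbitrary choices. The exact numbers depend on the data and the library version, so no output is shown, but the unrestricted tree should score 1.0 on its own training data, since it keeps splitting until every leaf is pure (the synthetic points are all distinct).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical synthetic data; flip_y randomly reassigns 20% of the labels,
# so a tree that memorizes the training set also memorizes noise.
X, y = make_classification(n_samples=500, n_features=10, flip_y=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

results = {}
for depth in (2, 4, None):  # None lets the tree grow until its leaves are pure
    tree = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_train, y_train)
    results[depth] = (
        tree.get_depth(),
        tree.score(X_train, y_train),  # accuracy on the data it was trained on
        tree.score(X_test, y_test),    # accuracy on unseen data
    )
    print(f"max_depth={depth}: depth={results[depth][0]}, "
          f"train={results[depth][1]:.2f}, test={results[depth][2]:.2f}")
```

Typically the training accuracy climbs toward 1.0 as the tree gets deeper, while the test accuracy stops improving or drops; that gap is the overfitting described above.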
One way to solve this problem is to prune the tree. The goal of pruning is to shrink the size of the tree. There are a few different pruning strategies, and we won’t go into the details of them here.
scikit-learn currently doesn’t prune the tree by default, but we can dig into the code a bit to prune it ourselves.
We’ve created a decision tree classifier for you and printed its accuracy. Let’s see how big this tree is.
If your classifier is named classifier, you can find the depth of the tree by printing
Print the depth of classifier’s decision tree. Take note of the accuracy as well.
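If you want to try this step outside the exercise, here is a minimal sketch. The lesson's dataset isn't shown here, so the data below is a hypothetical synthetic stand-in, and the variable names train_data, test_data, train_labels and test_labels are assumptions. In scikit-learn, a fitted tree's depth is available both as the tree_.max_depth attribute and through the get_depth() method, and score() returns the accuracy on whatever data you pass in.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for the lesson's data: a synthetic dataset.
X, y = make_classification(n_samples=1000, n_features=20, random_state=1)
train_data, test_data, train_labels, test_labels = train_test_split(X, y, random_state=1)

# No max_depth, so the tree grows as deep as it needs to.
classifier = DecisionTreeClassifier(random_state=1)
classifier.fit(train_data, train_labels)

# Two equivalent ways to read the depth of the fitted tree.
print(classifier.tree_.max_depth)
print(classifier.get_depth())

# Accuracy on the held-out test data.
print(classifier.score(test_data, test_labels))
```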
classifier should have a depth of 12. Let’s prune it! When you create classifier, set the parameter max_depth equal to 11.
What is the accuracy of the classifier after pruning the tree from a depth of 12 to a depth of 11?
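Here is one way the pruning step could look, again on hypothetical synthetic data with assumed variable names. Since the stand-in tree won't necessarily reach a depth of 12, the sketch trims one level from whatever depth the unrestricted tree reached, mirroring the 12-to-11 change in the exercise.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for the lesson's data.
X, y = make_classification(n_samples=1000, n_features=20, random_state=1)
train_data, test_data, train_labels, test_labels = train_test_split(X, y, random_state=1)

# Unrestricted tree, for comparison.
full = DecisionTreeClassifier(random_state=1).fit(train_data, train_labels)

# The exercise uses max_depth=11 on a depth-12 tree; here we remove one
# level from whatever depth the synthetic tree happened to reach.
pruned_depth = full.get_depth() - 1
classifier = DecisionTreeClassifier(max_depth=pruned_depth, random_state=1)
classifier.fit(train_data, train_labels)

print("full depth:", full.get_depth(), "accuracy:", full.score(test_data, test_labels))
print("pruned depth:", classifier.get_depth(), "accuracy:", classifier.score(test_data, test_labels))
```

Note that max_depth caps how deep the tree is allowed to grow while it is being built, so strictly speaking it stops growth early rather than trimming a finished tree. Whether accuracy goes up or down after this change depends on the data, so compare the two printed scores rather than assuming the smaller tree wins.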