Decision Tree (Classification)
Let us first create a scenario: say I want to buy a house and need to take out a loan from the bank. Before approving the loan, the bank is going to look at my history: What has my credit been like in the past? How much money do I make? (and maybe some other information). It then uses this information to determine whether lending me money is risky, i.e. whether I am likely to default. So how can the bank do this with decision trees?
Decision trees are formed by a collection of rules (if-then statements that partition the data) based on variables available from the data set. So in the example above, a very simple decision tree model could look like this:
The algorithm works by starting at the top of the tree (the root node) and traversing down its branches, asking a series of questions along the way. In the end, it reaches the bottom of the tree (a leaf node), which contains the final outcome. For example, if somebody has poor credit but a high income, the bank will say Yes, we will give him/her the loan.
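The traversal just described amounts to nested if-then statements. Below is a minimal sketch of a hand-written tree for the loan example; the function `approve_loan` is hypothetical, and its rules (including the assumption that excellent credit alone leads to Yes) are made up for illustration rather than learned from data.

```python
def approve_loan(credit: str, income: str) -> str:
    """Hypothetical hand-written tree: the root asks about credit, then income."""
    if credit == "excellent":  # question at the root node (assumed rule)
        return "Yes"           # leaf node
    # credit is poor: follow the branch and ask a second question
    if income == "high":
        return "Yes"           # leaf node
    return "No"                # leaf node


print(approve_loan("poor", "high"))
print(approve_loan("poor", "low"))
```

Each `return` corresponds to a leaf, and each `if` to an internal node; learning a tree means generating these rules automatically.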
Our task now is to learn how to generate the tree that creates these decision rules for us. Thankfully, the core method for learning a decision tree can be viewed as a recursive algorithm. A decision tree is "learned" by splitting the dataset into subsets based on the values of the input features/attributes. This process is repeated on each derived subset in a recursive manner called recursive partitioning:
1. Start at the tree's root node.
2. Select the best rule/feature that splits the data at the current node into two subsets (child nodes).
3. Repeat step 2 on each of the derived subsets until the tree can't be split any further. As we'll see later, we can also set restrictions that decide when the tree should stop growing.
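The three steps map directly onto a recursive function. The sketch below shows only the recursion: the split rule is a hypothetical placeholder (split a 1-D array of values at its median), since choosing the *best* split is covered in the next section, and `max_depth` / `min_size` are illustrative stopping restrictions.

```python
import numpy as np


def recursive_partition(values, depth=0, max_depth=2, min_size=2):
    """Return a nested dict describing how `values` gets partitioned.

    The median split is a placeholder, not the information-gain-based
    rule a real decision tree would use.
    """
    # stopping restrictions: depth limit, too few samples, or nothing left to split
    if depth >= max_depth or len(values) < min_size or np.all(values == values[0]):
        return {"leaf": values.tolist()}

    threshold = np.median(values)          # step 2: pick a rule for this node
    left = values[values <= threshold]
    right = values[values > threshold]
    if len(right) == 0:                    # the rule failed to separate the data
        return {"leaf": values.tolist()}

    # step 3: repeat on each derived subset
    return {
        "threshold": float(threshold),
        "left": recursive_partition(left, depth + 1, max_depth, min_size),
        "right": recursive_partition(right, depth + 1, max_depth, min_size),
    }


# step 1: start at the root node with the full dataset
tree = recursive_partition(np.array([1, 2, 3, 4, 5, 6, 7, 8]))
print(tree)
```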
There are a few additional details that we need to make more concrete, including how to pick the rule/feature to split on and, because it is a recursive algorithm, when to stop the recursion; in other words, when not to split another node in the tree.
Splitting criteria for classification trees
The first question is: what is the best rule/feature to split on, and how do we measure that? One way to determine this is by choosing the one that maximizes the Information Gain (IG) at each split:

$$IG(D_p, f) = I(D_p) - p_{left} \, I(D_{left}) - p_{right} \, I(D_{right})$$

$IG$: Information Gain
$f$: feature to perform the split
$I$: some impurity measure that we'll look at in the subsequent section
$D_p$: training subset of the parent node
$D_{left}$, $D_{right}$: training subset of the left/right child node
$p_{left}$, $p_{right}$: proportion of parent node samples that ended up in the left/right child node after the split, $\frac{N_{left}}{N_p}$ or $\frac{N_{right}}{N_p}$, where:
$N_p$: number of samples in the parent node
$N_{left}$: number of samples in the left child node
$N_{right}$: number of samples in the right child node
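The information gain computation translates almost line for line into code. A minimal sketch, assuming the labels are a NumPy array, the split is given as a boolean mask (True means the sample goes to the left child), and entropy is used as the impurity measure $I$; the function names are illustrative, and any other impurity function can be passed via `impurity`.

```python
import numpy as np


def entropy(y):
    """Entropy of a label array, summing over non-empty classes only."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1 / p)))  # equals -sum(p * log2(p))


def information_gain(y_parent, left_mask, impurity=entropy):
    """IG = I(D_p) - p_left * I(D_left) - p_right * I(D_right)."""
    y_left, y_right = y_parent[left_mask], y_parent[~left_mask]
    n_p = len(y_parent)
    return (
        impurity(y_parent)
        - len(y_left) / n_p * impurity(y_left)
        - len(y_right) / n_p * impurity(y_right)
    )


y = np.array([0, 0, 1, 1])
print(information_gain(y, np.array([True, True, False, False])))  # perfect split
print(information_gain(y, np.array([True, False, True, False])))  # useless split
```

The perfect split separates the two classes completely and recovers the full parent entropy of 1, while the useless split leaves each child as mixed as the parent and gains nothing.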
Impurity
The two most common impurity measures are entropy and the Gini index.
Entropy
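Entropy is defined as:

$$I_H(t) = -\sum_{i=1}^{C} p(i \mid t) \log_2 p(i \mid t)$$

for all non-empty classes, $p(i \mid t) \neq 0$, where:
$p(i \mid t)$ is the proportion (or frequency or probability) of the samples that belong to class $i$ for a particular node $t$
$C$ is the number of unique class labels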
The entropy is therefore 0 if all samples at a node belong to the same class, and it is maximal if the class distribution is uniform. For example, in a binary class setting, the entropy is 0 if $p(i=1 \mid t) = 1$ or $p(i=0 \mid t) = 1$. And if the classes are distributed uniformly with $p(i=1 \mid t) = 0.5$ and $p(i=0 \mid t) = 0.5$, the entropy is 1, which we can visualize by plotting the entropy for the binary class setting below.
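As a quick numeric check of these two extremes, here is a minimal sketch that computes entropy from a vector of class proportions $p(i \mid t)$ (the function name `entropy_from_proportions` is illustrative) and evaluates it for a few binary class distributions:

```python
import numpy as np


def entropy_from_proportions(p):
    """Entropy (base 2) of class proportions; empty classes (p == 0) are skipped."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]  # the sum runs over non-empty classes only
    # sum of p * log2(1/p) equals -sum of p * log2(p), without producing -0.0
    return float(np.sum(p * np.log2(1 / p)))


for p1 in [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0]:
    print(f"p(i=1|t) = {p1:.2f} -> entropy = {entropy_from_proportions([p1, 1 - p1]):.3f}")
```

The printed entropy climbs from 0 at $p(i=1 \mid t) = 0$ to its maximum of 1 at 0.5 and falls back to 0 at 1, symmetric around 0.5.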
Gini Index
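Gini Index is defined as:

$$I_G(t) = \sum_{i=1}^{C} p(i \mid t) \big(1 - p(i \mid t)\big) = 1 - \sum_{i=1}^{C} p(i \mid t)^2$$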
Compared to entropy, whose maximum in the binary case is 1, the maximum value of the Gini index in the binary case is 0.5, which occurs when the classes are perfectly balanced in a node. On the other hand, the minimum value of the Gini index is 0, which occurs when only one class is represented in a node (a node with a lower Gini index is said to be more "pure").
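The same kind of check for the Gini index, again a minimal sketch from class proportions (`gini_from_proportions` is an illustrative name). It uses the equivalent form $1 - \sum_i p(i \mid t)^2$, so no logarithm is involved:

```python
import numpy as np


def gini_from_proportions(p):
    """Gini index 1 - sum_i p_i^2 for a vector of class proportions."""
    p = np.asarray(p, dtype=float)
    return float(1.0 - np.sum(p ** 2))


print(gini_from_proportions([1.0, 0.0]))          # pure node
print(gini_from_proportions([0.5, 0.5]))          # balanced binary node
print(gini_from_proportions([1/3, 1/3, 1/3]))     # balanced 3-class node
```

With three balanced classes the Gini index reaches $1 - 1/3 \approx 0.667$, so 0.5 is the maximum only in the binary case (in general it is $1 - 1/C$).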
This time we plot Entropy and Gini index together to compare them against each other.
As we can see from the plot, there is not much difference between the two (they both increase and decrease over a similar range). In practice, the Gini index and entropy typically yield very similar results, and it is often not worth spending much time evaluating decision tree models with different impurity criteria. As for which one to use, consider the Gini index: it doesn't require computing a logarithm, which can make it slightly faster to compute.
Decision trees can also be used for regression tasks. The only difference is that instead of using the Gini index or entropy as the impurity function, we use criteria such as the MSE (mean squared error):

$$MSE(t) = \frac{1}{N_t} \sum_{i \in D_t} \left(y^{(i)} - \hat{y}_t\right)^2$$

where $\hat{y}_t$ is the average of the responses at node $t$, and $N_t$ is the number of observations that reached node $t$. This simply says: compute the difference between each observation's response and the average response, square it, and take the average.
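For regression, the impurity of a node is the mean squared deviation of its responses from the node average, and the information gain formula stays the same with the MSE plugged in as the impurity. A minimal sketch on made-up response values (the function name `mse_impurity` is illustrative):

```python
import numpy as np


def mse_impurity(y):
    """Mean squared error of the responses at a node around their average."""
    y = np.asarray(y, dtype=float)
    y_hat = y.mean()  # the node's prediction: the average response
    return float(np.mean((y - y_hat) ** 2))


parent = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
left, right = parent[:3], parent[3:]
gain = (
    mse_impurity(parent)
    - len(left) / len(parent) * mse_impurity(left)
    - len(right) / len(parent) * mse_impurity(right)
)
print(mse_impurity(parent), mse_impurity(left), mse_impurity(right), gain)
```

Splitting the small responses from the large ones removes most of the parent's squared error, which is exactly what a regression tree looks for.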
Concrete Example
Here we'll calculate the Entropy score by hand to hopefully make things a bit more concrete. Using the bank loan example again, suppose at a particular node, there are 80 observations, of whom 40 were classified as Yes (the bank will issue the loan) and 40 were classified as No.
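We can first calculate the Entropy before making a split:

$$I_H(D_p) = -\left(\frac{40}{80}\log_2\frac{40}{80} + \frac{40}{80}\log_2\frac{40}{80}\right) = 1$$

Suppose we try splitting on Income and the child nodes turn out to be:

Left (Income = high): 30 Yes and 10 No
Right (Income = low): 10 Yes and 30 No

$$I_H(D_{left}) = -\left(\frac{30}{40}\log_2\frac{30}{40} + \frac{10}{40}\log_2\frac{10}{40}\right) \approx 0.811$$

$$I_H(D_{right}) = -\left(\frac{10}{40}\log_2\frac{10}{40} + \frac{30}{40}\log_2\frac{30}{40}\right) \approx 0.811$$

$$IG(D_p, Income) = 1 - \frac{40}{80}(0.811) - \frac{40}{80}(0.811) \approx 0.189$$

Next we repeat the same process and evaluate the split based on Credit:

Left (Credit = excellent): 20 Yes and 0 No
Right (Credit = poor): 20 Yes and 40 No

$$I_H(D_{left}) = -\frac{20}{20}\log_2\frac{20}{20} = 0$$

$$I_H(D_{right}) = -\left(\frac{20}{60}\log_2\frac{20}{60} + \frac{40}{60}\log_2\frac{40}{60}\right) \approx 0.918$$

$$IG(D_p, Credit) = 1 - \frac{20}{80}(0) - \frac{60}{80}(0.918) \approx 0.311$$

Since splitting on Credit yields the higher information gain (about 0.311 versus 0.189 for Income), the algorithm will choose Credit as the feature to split on.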
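The entropy and information gain of both candidate splits can be reproduced from the class counts alone. A minimal sketch; the count lists and the helper names `entropy` and `information_gain` are illustrative:

```python
import numpy as np


def entropy(counts):
    """Entropy from class counts, ignoring empty classes."""
    counts = np.asarray(counts, dtype=float)
    p = counts[counts > 0] / counts.sum()
    return float(np.sum(p * np.log2(1 / p)))


def information_gain(parent, children):
    """Parent entropy minus the size-weighted entropy of the child nodes."""
    n_parent = sum(parent)
    weighted = sum(sum(child) / n_parent * entropy(child) for child in children)
    return entropy(parent) - weighted


parent = [40, 40]  # [Yes, No]
candidates = {
    "Income": [[30, 10], [10, 30]],  # high, low
    "Credit": [[20, 0], [20, 40]],   # excellent, poor
}
gains = {name: information_gain(parent, kids) for name, kids in candidates.items()}
print({name: round(g, 3) for name, g in gains.items()})
print("best split:", max(gains, key=gains.get))
```

This prints gains of about 0.189 for Income and 0.311 for Credit, so Credit is selected.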
If we have more features, the decision tree algorithm simply tries every possible split and chooses the one that maximizes the information gain. If a feature is continuous, we can take the unique values of that feature in sorted order and try every candidate split value (threshold), using the midpoint between each pair of adjacent values (e.g. the unique values 1, 2, 3 result in trying splits at 1.5 and 2.5). Alternatively, to speed up computation, we can bin the unique values into buckets and split on the buckets.
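Generating candidate thresholds for a continuous feature takes only a line or two with NumPy. A minimal sketch of both options; the function names are illustrative, and using quantile cut points (`n_bins`) to form the buckets is an assumption, since the text does not specify how buckets are built:

```python
import numpy as np


def midpoint_thresholds(x):
    """Midpoints between adjacent sorted unique values of a continuous feature."""
    unique = np.unique(x)  # np.unique returns the values already sorted
    return (unique[:-1] + unique[1:]) / 2


def binned_thresholds(x, n_bins=4):
    """Cheaper alternative: only try the quantile cut points between n_bins buckets."""
    qs = np.linspace(0, 1, n_bins + 1)[1:-1]  # interior quantiles, e.g. 0.25, 0.5, 0.75
    return np.unique(np.quantile(x, qs))


x = np.array([3, 1, 2, 2, 3])
print(midpoint_thresholds(x))  # [1.5 2.5]
```

The midpoint version tries up to (number of unique values - 1) thresholds, while the binned version caps the work at `n_bins - 1` thresholds regardless of how many unique values there are.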
When To Stop Recursing
The other question that we need to address is when to stop the tree from growing. There are some early stopping criteria that are commonly used to prevent the tree from overfitting.
Maximum depth: The length of the longest path from the root node to a leaf node will not exceed this value. This is the most commonly tuned hyperparameter for tree-based methods.
Minimum sample split: The minimum number of samples a node must contain for it to be split.
Minimum information gain: The minimum information gain required for splitting on the best feature.
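These criteria boil down to a few checks made at each node before accepting a split. A sketch, where `should_stop` and its parameter names and defaults are hypothetical; the purity check reflects the "can't be split any further" case. For comparison, scikit-learn's `DecisionTreeClassifier` exposes similar knobs as `max_depth`, `min_samples_split` and `min_impurity_decrease` (the latter is weighted by the node's share of the training set).

```python
def should_stop(depth, n_samples, n_classes, best_gain,
                max_depth=3, min_samples_split=2, min_information_gain=1e-7):
    """Return True if the node should become a leaf instead of being split."""
    return (
        n_classes == 1                         # node is already pure
        or depth >= max_depth                  # maximum depth reached
        or n_samples < min_samples_split       # too few samples to split
        or best_gain < min_information_gain    # best split barely helps
    )


print(should_stop(depth=3, n_samples=50, n_classes=2, best_gain=0.3))
print(should_stop(depth=1, n_samples=50, n_classes=2, best_gain=0.3))
```

The first call stops because the depth limit is reached; the second passes every check, so that node would be split.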
And that's pretty much it for classification trees! For a more visually appealing explanation, the following website uses interactive visualizations to demonstrate how decision trees work: A Visual Introduction to Machine Learning
Implementation
With all of that in mind, the following section implements a toy classification tree algorithm.
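Putting the pieces together, below is a minimal sketch of such a toy classification tree in NumPy. It assumes numeric features, labels encoded as integers 0..K-1, binary splits of the form `x[:, j] <= threshold`, Gini impurity, midpoint thresholds, and the three stopping criteria listed earlier. The class name `ToyTree` and its parameters are illustrative, not the notebook's original implementation.

```python
import numpy as np


def gini(y):
    """Gini impurity of an integer label array."""
    p = np.bincount(y) / len(y)
    return 1.0 - np.sum(p ** 2)


class ToyTree:
    """Toy classification tree using Gini impurity and binary threshold splits."""

    def __init__(self, max_depth=3, min_samples_split=2, min_information_gain=1e-7):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_information_gain = min_information_gain

    def fit(self, X, y):
        self.root_ = self._grow(np.asarray(X, dtype=float), np.asarray(y), depth=0)
        return self

    def _best_split(self, X, y):
        """Try every feature and midpoint threshold; return the highest-gain split."""
        parent_impurity, n = gini(y), len(y)
        best = (None, None, -np.inf)  # (feature, threshold, gain)
        for j in range(X.shape[1]):
            values = np.unique(X[:, j])
            for t in (values[:-1] + values[1:]) / 2:
                left = X[:, j] <= t
                n_left = left.sum()
                gain = (parent_impurity
                        - n_left / n * gini(y[left])
                        - (n - n_left) / n * gini(y[~left]))
                if gain > best[2]:
                    best = (j, t, gain)
        return best

    def _grow(self, X, y, depth):
        # leaf prediction: majority class of the samples reaching this node
        leaf = {"label": int(np.bincount(y).argmax())}
        if (depth >= self.max_depth or len(y) < self.min_samples_split
                or len(np.unique(y)) == 1):
            return leaf
        j, t, gain = self._best_split(X, y)
        if j is None or gain < self.min_information_gain:
            return leaf
        left = X[:, j] <= t
        return {"feature": j, "threshold": t,
                "left": self._grow(X[left], y[left], depth + 1),
                "right": self._grow(X[~left], y[~left], depth + 1)}

    def _predict_one(self, x):
        node = self.root_
        while "label" not in node:  # walk from the root down to a leaf
            node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
        return node["label"]

    def predict(self, X):
        return np.array([self._predict_one(x) for x in np.asarray(X, dtype=float)])


# tiny example: class 1 exactly when the second feature is large
X = np.array([[1, 1], [2, 1], [3, 2], [1, 8], [2, 9], [3, 9]])
y = np.array([0, 0, 0, 1, 1, 1])
tree = ToyTree(max_depth=2).fit(X, y)
print(tree.root_["feature"], tree.root_["threshold"])
print(tree.predict([[2, 0], [2, 10]]))
```

The tree is stored as nested dictionaries, where any dict with a `"label"` key is a leaf. The split search is brute force over every feature and every midpoint, which is fine for a toy but slow on large data; swapping `gini` for an entropy function would give the information-gain variant.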
We will load the Iris dataset and use it as a sample dataset to test our algorithm. This dataset consists of 3 different types of irises (Setosa, Versicolour, and Virginica). The features are stored as a 150x4 numpy.ndarray, where the rows are the samples and the columns are Sepal Length, Sepal Width, Petal Length and Petal Width.
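As a reference point for any toy implementation, here is a sketch that loads Iris with scikit-learn and fits its `DecisionTreeClassifier`, which is a swapped-in library model rather than the toy algorithm. The train/test split, `max_depth=3`, and `random_state` values are arbitrary choices for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
print(X.shape, iris.feature_names)  # 150 x 4 feature matrix

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1, stratify=y)

# criterion="entropy" splits on information gain; the default "gini" uses the Gini index
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=1)
clf.fit(X_train, y_train)
print("test accuracy:", clf.score(X_test, y_test))
```

Switching `criterion` between `"entropy"` and `"gini"` is an easy way to confirm that the two impurity measures usually give very similar trees.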
Advantages of decision trees:
Features don't require scaling or normalization
Great at dealing with data that have lots of categorical features
Can be displayed graphically, thus making it highly interpretable (in the next code chunk)
It is non-parametric, so it can outperform linear models when the relationship between the features and the response is highly non-linear
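The first advantage follows from how splits are found: a threshold depends only on the ordering of a feature's values, so any strictly increasing rescaling moves the threshold but leaves the partition of the samples unchanged. A small sketch on made-up data, using a single-feature Gini split search (`best_partition` is an illustrative helper):

```python
import numpy as np


def gini(y):
    p = np.bincount(y) / len(y)
    return 1.0 - np.sum(p ** 2)


def best_partition(x, y):
    """Return (threshold, left_mask) of the best Gini split on a single feature."""
    values = np.unique(x)
    best_t, best_score = None, np.inf
    for t in (values[:-1] + values[1:]) / 2:
        left = x <= t
        # weighted child impurity; minimizing it maximizes information gain
        score = left.mean() * gini(y[left]) + (~left).mean() * gini(y[~left])
        if score < best_score:
            best_t, best_score = t, score
    return best_t, x <= best_t


x = np.array([0.2, 0.5, 0.9, 1.4, 2.0, 3.1])
y = np.array([0, 0, 0, 1, 1, 1])
t_raw, mask_raw = best_partition(x, y)
t_scaled, mask_scaled = best_partition(1000 * x + 5, y)  # rescaled feature
print(t_raw, t_scaled, np.array_equal(mask_raw, mask_scaled))
```

The threshold value changes with the scale, but the same samples land in the same child node, which is why no standardization step is needed before fitting a tree.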
For visualizing the decision tree, you might need to have graphviz installed. For Mac users, try running brew install graphviz or follow the instructions in this link.
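One way to produce that graphical display, assuming scikit-learn is installed, is `sklearn.tree.export_graphviz`: with `out_file=None` it returns the fitted tree in Graphviz's DOT language as a string. The sketch below writes it to a file that the graphviz `dot` command-line tool can render; the file names and `max_depth=2` are arbitrary.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# DOT source for the fitted tree, with readable feature and class names
dot_source = export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
)

with open("iris_tree.dot", "w") as f:
    f.write(dot_source)
# then, in a shell with graphviz installed: dot -Tpng iris_tree.dot -o iris_tree.png
```

If installing graphviz is inconvenient, `sklearn.tree.plot_tree` draws the same tree with matplotlib instead.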