Decision Tree Algorithm for Classification


The article gives an introduction to decision tree algorithm for classification, with example in Python.


Imagine you have received a new job offer and need to decide whether to take it or leave it. You consider several factors: salary, distance and commute time to the office, other perks and benefits, career growth, and so on. But you don’t weigh all the factors at once. Your brain processes the information through a series of if-else branches, one check at a time.

You start thinking about the salary first, which becomes the main factor, or starting point, for your analysis.

If the salary criterion is met and it’s above $50K, you then ask whether the commute time to the office is more than an hour or less. If the office is nearby and easy to reach, you start thinking about whether it offers you coffee and other perks. If all those conditions are met, you finally go ahead and accept the offer.

The decision tree algorithm works in exactly the same fashion. In some sense, it is a replica of how the human brain makes decisions: a series of questions, each processed one at a time in a sequential manner.
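As a purely illustrative sketch (the factor names and thresholds below are hypothetical, not taken from any dataset), the job-offer reasoning can be written as a chain of if-else checks, which is exactly the structure a decision tree encodes:

# A hand-written "decision tree" for the job-offer example.
# The thresholds and factors are hypothetical and only for illustration.
def accept_offer(salary, commute_minutes, has_perks):
    if salary >= 50_000:                  # root node: salary check
        if commute_minutes <= 60:         # internal node: commute time
            if has_perks:                 # internal node: coffee and other perks
                return "Accept offer"     # leaf node
            return "Decline offer"
        return "Decline offer"
    return "Decline offer"

print(accept_offer(salary=65_000, commute_minutes=30, has_perks=True))  # Accept offer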


Introduction to the decision tree algorithm:

A decision tree is a machine learning algorithm that is used for classification and regression tasks. It works by creating a tree-like model of decisions and their possible consequences, with the goal of accurately predicting the outcome of a given input.

To create a decision tree, the algorithm begins by considering all the available features (also called “attributes”) of the input data. It then selects the feature that best splits the data into different classes or categories. This process is repeated for each split, with the algorithm choosing the feature that best divides the data at each step. The process continues until the tree is fully grown, or until a stopping criterion is reached (such as a maximum tree depth or a minimum number of samples in a leaf node).

Once the decision tree is created, it can be used to make predictions on new input data by following the path down the tree based on the feature values of the input data.

Decision trees are easy to interpret and understand, and they can handle both continuous and categorical data. However, they can be prone to overfitting, particularly if the tree is allowed to grow too deep. To mitigate this, techniques such as pruning (removing branches from the tree) or limiting the maximum depth of the tree can be used.
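As a minimal sketch of this workflow (on a tiny made-up dataset, purely for illustration), scikit-learn’s DecisionTreeClassifier grows the tree when fit() is called and follows the learned splits when predict() is called; max_depth and min_samples_leaf are examples of the stopping criteria mentioned above:

from sklearn.tree import DecisionTreeClassifier

# Tiny made-up dataset: two numeric features per sample and a binary label
X = [[50, 30], [60, 90], [80, 20], [30, 45], [90, 15], [40, 70]]
y = [1, 0, 1, 0, 1, 0]

# Stopping criteria such as max_depth and min_samples_leaf keep the tree small
model = DecisionTreeClassifier(max_depth=2, min_samples_leaf=1, random_state=0)
model.fit(X, y)

# Prediction follows the learned splits from the root node down to a leaf
print(model.predict([[70, 25]]))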

In machine learning, a decision tree may be defined as a non-parametric supervised algorithm. It uses a flowchart-like tree structure of if-else conditions and obtains predictions from a sequence of feature-based splits. It starts with a root node, your first point of consideration (salary in the job-offer example), and ends with a decision made at the leaves, or terminal nodes (whether you accept or decline the offer).

Important Terminologies related to a Decision Tree algorithm

Here are some terms and terminologies related to decision trees:

  1. Root node: The top node of a decision tree, representing the entire population or sample.
  2. Splitting: The process of dividing a node into two or more sub-nodes based on a feature or attribute value.
  3. Decision node: A node that represents a decision to be made based on the value of a feature or attribute.
  4. Leaf node: A terminal node that does not have any sub-nodes, representing a classification or prediction.
  5. Pruning: The process of removing branches from a decision tree to reduce overfitting and improve generalization to new data.
  6. Decision boundary: The line or plane that separates different classes or categories in the data.
  7. Gini index: A measure of the purity of the nodes in a decision tree, based on the proportion of samples belonging to a particular class (a short calculation sketch follows this list).
  8. Information gain: A measure of the reduction in entropy (randomness or uncertainty) caused by splitting the data based on a particular feature.
  9. Overfitting: The phenomenon where a model fits the training data too well and does not generalize well to new data.
  10. Underfitting: The phenomenon where a model does not fit the training data well and therefore performs poorly on both the training and test data.
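As referenced in item 7, here is a minimal sketch of how the Gini index of a single node could be computed; the class labels below are made up for illustration, and the standard definition Gini = 1 − ∑ p(i)² is assumed:

# Gini index of a node: 1 - sum of squared class proportions
def gini_index(labels):
    n = len(labels)
    proportions = [labels.count(c) / n for c in set(labels)]
    return 1 - sum(p ** 2 for p in proportions)

print(gini_index(["A", "A", "B", "B"]))  # 0.5 -> maximally impure two-class node
print(gini_index(["A", "A", "A", "A"]))  # 0.0 -> pure node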

A decision tree in general is termed a Classification and Regression Tree (CART): it can be used both for classification problems and for predicting continuous variables. However, in this article we will restrict ourselves to a real-world classification example only.


Application of Decision Tree

There are multiple real-life applications of Decision trees. Some examples include:

  1. Medical diagnosis: Make medical diagnoses based on a set of symptoms or test results.
  2. Credit approval: Banks and financial institutions can use decision trees to predict the likelihood of an individual defaulting on a loan or credit card based on their credit history and other factors.
  3. Marketing: Predict customer behavior and make targeted marketing campaigns based on factors such as age, income, and purchasing history.
  4. Fraud detection: Identify fraudulent transactions in areas such as credit card use or insurance claims.
  5. Oil reservoir characterization: Predict the characteristics of an oil reservoir based on data such as rock type and porosity.
  6. Customer churn prediction: Predict the likelihood of a customer churning (leaving a company) based on factors such as their usage patterns and customer service interactions.

Growing a tree

The decision tree algorithm starts at the root node and progresses downward in search of the purest set of data points. Simply put, the objective of the algorithm is to create splits at different nodes such that the resulting nodes (sets of observations) are as homogeneous as possible.

Imagine three nodes: node A is an equal mix of blue and yellow dots and in that sense the most impure, node C is all blue and the purest set of data points, and node B falls in between A and C.


Concept of Entropy and Splits: 

In decision tree analysis, entropy is a measure of the impurity or randomness of a set of data. It is commonly used to evaluate the quality of a split in a decision tree. The idea is that a split that results in pure, homogeneous subsets (low entropy) is more useful for making accurate predictions than a split that results in mixed or heterogeneous subsets (high entropy).

In information theory, entropy measures the uncertainty or randomness of a random variable. The entropy of a set of data is calculated using the following formula:

Entropy = – ∑(p(i) * log2(p(i)))

where p(i) is the proportion of data points in the set that belong to class i.

For example, consider a set of data with two classes, A and B. If the data is perfectly balanced, with 50% of the data points belonging to class A and 50% belonging to class B, the entropy would be 1. If the data is completely imbalanced, with all data points belonging to class A or all data points belonging to class B, the entropy would be 0.
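A minimal sketch of this calculation (with the usual convention that a class with proportion 0 contributes nothing to the sum):

import math

# Entropy of a node: -sum(p(i) * log2(p(i))) over the classes present
def entropy(labels):
    n = len(labels)
    proportions = [labels.count(c) / n for c in set(labels)]
    return sum(-p * math.log2(p) for p in proportions)

print(entropy(["A", "B", "A", "B"]))  # 1.0 -> perfectly balanced, maximum impurity
print(entropy(["A", "A", "A", "A"]))  # 0.0 -> pure node, no uncertainty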

In a decision tree, the entropy of a set of data is used to evaluate the quality of a split. The goal is to find the split that results in the lowest possible entropy, so that the resulting subsets are as pure as possible.

Methods of splitting and growing a tree (concept of information gain): 

While building a decision tree, it is very important to choose the right feature or predictor for each split as the tree grows from top to bottom. To select these features, the concept of information gain is used, which is built on the principle of maximizing the reduction in entropy while traversing from the top node to the bottom nodes.


The concept of information gain is presented below:

Say, in a real-life problem, you need to decide which factor is more important for predicting whether someone goes to the gym: their energy level or their motivation. To compare the two, you split the observed responses on each factor and compare the entropy of the resulting child nodes with the entropy of the parent node.


In that example, the information gain (reduction in entropy) turns out to be higher when splitting on energy level, so the tree would select “Energy” as the next splitting criterion.

The split with the highest information gain is taken as the first split, and the process continues until all child nodes are pure or the information gain is 0. This is why decision tree algorithms are termed greedy algorithms: at each step they pick the locally best split and keep growing the tree until every node becomes completely pure.
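The per-node numbers from the energy/motivation example are not reproduced here, but the calculation itself can be sketched as follows (the labels are made up for illustration): information gain is the entropy of the parent node minus the weighted average entropy of the child nodes produced by the split.

import math

def entropy(labels):
    n = len(labels)
    return sum(-(labels.count(c) / n) * math.log2(labels.count(c) / n)
               for c in set(labels))

# Information gain = entropy(parent) - weighted average entropy of the children
def information_gain(parent, children):
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted

# Hypothetical split on "Energy": both children end up pure, so the full
# entropy of the parent (1.0) is removed by the split
parent = ["go", "go", "go", "skip", "skip", "skip"]
children = [["go", "go", "go"], ["skip", "skip", "skip"]]
print(information_gain(parent, children))  # 1.0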

However, growing a tree to reach the purest set of nodes may not be always feasible owing to computational challenges and overfitting problems on the training data. That is why the concept of pruning comes into the picture. The growth of the decision tree can be restricted by cutting the branches using hyperparameter tuning or by cost complexity pruning. The details of this are outside the scope of this article. However, we should have a clear understanding of these aspects as well while building a model using the CART algorithm. 
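Although the details are outside the scope of this article, a brief sketch of how cost complexity pruning is exposed in scikit-learn is shown below; it assumes the X_train and y_train variables created in the example later in this article, and larger values of the ccp_alpha parameter prune the tree more aggressively:

from sklearn.tree import DecisionTreeClassifier

# cost_complexity_pruning_path returns the candidate alpha values at which
# branches of the fully grown tree would be pruned away
path = DecisionTreeClassifier(random_state=1).cost_complexity_pruning_path(X_train, y_train)

# Refit with a non-zero ccp_alpha: branches whose complexity does not "pay for
# itself" in impurity reduction are cut back, giving a smaller, less overfit tree
mid_alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]
pruned_clf = DecisionTreeClassifier(random_state=1, ccp_alpha=mid_alpha)
pruned_clf.fit(X_train, y_train)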

Example with Python implementation:

We will now focus on a real-world problem and implement a decision tree algorithm using the scikit-learn package in Python.

We use a dataset from the UCI Machine Learning Repository called the “Pima Indians Diabetes Database”.

This dataset is originally from the National Institute of Diabetes and Digestive and Kidney Diseases. The task is to diagnostically predict whether a patient has diabetes based on certain diagnostic measurements. Predictor variables include the number of pregnancies the patient has had, their BMI, insulin level, age, and so on. The label (dependent) variable is a 0/1 binary flag, where 0 means the patient is non-diabetic and 1 means the patient is diabetic.

Sklearn decision tree

 
# Load libraries
import pandas as pd
from sklearn.tree import DecisionTreeClassifier # Import Decision Tree Classifier
from sklearn.model_selection import train_test_split # Import train_test_split function
from sklearn import metrics #Import scikit-learn metrics module for accuracy calculation
col_names = ['pregnant', 'glucose', 'bp', 'skin', 'insulin', 'bmi', 'pedigree', 'age', 'label']
# load dataset
pima = pd.read_csv("diabetes.csv", header=0, names=col_names)
pima.head()

First, we import the required libraries and also set the column names for the dataset for easy interpretation. 

The call to pima.head() displays the first 5 rows of the dataset.

After loading the data, we examine its structure and variables, determine the target and feature variables, and divide the data into training and testing sets in a 70:30 ratio.

 
#split dataset in features and target variable
feature_cols = ['pregnant', 'insulin', 'bmi', 'age','glucose','bp','pedigree']
X = pima[feature_cols] # Features
y = pima.label # Target variable
# Split dataset into training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1) # 70% training and 30% test

Once the dataset is split into training and testing parts, we fit the decision tree model on the training set, predict the results on both the training and testing sets, and then check the accuracy. We have also restricted max_depth, one of the key hyperparameters for a decision tree model: it limits the longest path from the root node to a leaf node. By doing this we also introduce the concept of hyperparameter tuning to keep the tree from overfitting.

 
# Create Decision Tree classifier object
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)
# Train Decision Tree Classifier
clf = clf.fit(X_train,y_train)
#Predict the response for test dataset
y_pred_train = clf.predict(X_train)
y_pred_test = clf.predict(X_test)
# Model Accuracy, how often is the classifier correct?
print("Accuracy in training set:",metrics.accuracy_score(y_train, y_pred_train))
print("Accuracy in testing set:",metrics.accuracy_score(y_test, y_pred_test))

Model Accuracy

The model accuracy on the training and testing datasets is printed by the code above.

With a maximum depth of 3, the decision tree achieves very similar accuracy in both the training and testing datasets, which can be considered a good fit.

In the next and final step, we visualize the model using the export_graphviz function from scikit-learn together with the pydotplus and graphviz packages.

 
from io import StringIO  # sklearn.externals.six has been removed from recent scikit-learn versions
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus

# Export the fitted tree to DOT format and render it as a PNG image
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True,
                feature_names=feature_cols, class_names=['0', '1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('diabetes.png')
Image(graph.create_png())

The final structure of the fitted tree is rendered by the code above and saved as diabetes.png.

Interpretation of the model: As can be seen, the branching starts with “glucose” as the first factor. Depending on the level of glucose in someone’s blood, the next factor the model considers is “bmi”, or Body Mass Index. Finally, after glucose and BMI have been checked, the third most important variable here is “age”.

The color-coding of the nodes reflects this as well: on the class=0 (non-diabetic) side, nodes turn dark orange as entropy decreases, whereas on the class=1 (diabetic) side, nodes turn dark blue as entropy decreases.
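This interpretation can also be cross-checked programmatically: a fitted scikit-learn tree exposes the feature_importances_ attribute, and for the model above the largest values would be expected for glucose, bmi, and age:

# Rank the features used by the fitted tree (clf and feature_cols come from the code above)
import pandas as pd

importances = pd.Series(clf.feature_importances_, index=feature_cols)
print(importances.sort_values(ascending=False))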

End notes: 

In this article, we have introduced the concept of a decision tree for the classification problem and touched upon how a tree is built, with a real-world example. I hope this is a good starting point for learning and applying decision trees. Further reading can explore hyperparameter tuning and the application of decision trees to continuous variables.

