Matplotlib in Machine Learning – Shiksha Online
In applied Statistics and Machine Learning, Data Visualization is one of the most important skills for gaining a qualitative understanding of the data at hand. It helps you explore the data and extract relevant information from it by revealing patterns, relationships, outliers, and much more. This article explores the use of Matplotlib in Machine Learning.
Visualizations are one of the easiest ways to analyze and absorb information. Data Visualization is also central to high-level data analysis during Exploratory Data Analysis (EDA). Python offers multiple data visualization libraries, the most popular and widely used of which is the Matplotlib library. In this blog, we will cover Matplotlib in Machine Learning in the following sections:
- Introduction to Matplotlib
- Installing Matplotlib
- Importing Matplotlib
- Creating a Simple Plot using Matplotlib
- Working with Figures and Axes
- Important Matplotlib Plots in Machine Learning
- Use of Matplotlib in Machine Learning
- Endnotes
Introduction to Matplotlib
Matplotlib is an open-source plotting library for Python. It is mainly used to create static 2D plots, but it also supports animated and interactive figures and has some support for 3D visualizations.
It is a comprehensive library that makes producing both simple and advanced plots straightforward and intuitive.
It can be used in Python scripts, Jupyter notebooks, and web application servers.
Installing Matplotlib
Let’s start by installing the library in your working environment:
# Windows, Linux, and macOS users (from a terminal):
python -m pip install -U matplotlib
# To install Matplotlib from inside a Jupyter Notebook cell:
%pip install matplotlib
# To install Matplotlib from the Anaconda Prompt:
conda install matplotlib
Importing Matplotlib
Now let’s import the Matplotlib library along with the other libraries we might need today:
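import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
# or, equivalently:
import matplotlib.pyplot as plt
%matplotlib inline
pyplot is Matplotlib’s state-based interface: its functions create figures and modify their characteristics, such as the plots, labels, and titles they contain.
%matplotlib inline is an IPython “magic” command rather than a Python function. It makes plots appear directly below the code cell in a Jupyter Notebook. It only works inside IPython/Jupyter (not in plain .py scripts), and recent Jupyter versions typically display plots inline by default.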
Creating a Simple Plot using Matplotlib
A line plot is the most basic plot to create. Simply use plt.plot() as shown below:
#Example
plt.plot([32,26,43,41,16,37,29])
plt.ylabel('marks out of 50')
plt.show()
If you pass a single list or array to the plot() function, Matplotlib assumes it is a sequence of y values and automatically generates the x values 0, 1, 2, and so on (one per y value). To plot x versus y, pass both sequences:
#Example
plt.plot([1,2,3,4,5,6,7], [32,26,43,41,16,37,29], 'g^')
plt.xlabel('roll no.')
plt.ylabel('marks out of 50')
plt.show()
Do you see how the style of the plot changed above? For every x, y pair of arguments, there is an optional third argument, the format string, which sets the color (g for green) and the marker style (^ for triangle markers). Because this format string specifies a marker but no line style, the points are not connected by a line.
If you omit the format string, Matplotlib draws a solid line in blue: in current versions this is the first color of the default color cycle ('C0'), while versions before 2.0 used pure blue, ‘b-‘.
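A minimal sketch to check both behaviors described above by inspecting the Line2D objects that plt.plot() returns. It assumes a recent Matplotlib 3.x and selects the non-interactive Agg backend so it runs without a display; the variable names are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window is needed
import matplotlib.pyplot as plt

marks = [32, 26, 43, 41, 16, 37, 29]

# Only y values given: Matplotlib generates x = 0, 1, ..., len(marks) - 1
(implicit_line,) = plt.plot(marks)
print("generated x:", list(implicit_line.get_xdata()))

# Format string 'g^': green color, upward-triangle markers, no connecting line
(fmt_line,) = plt.plot(range(1, 8), marks, 'g^')

# The same style written out with explicit keyword arguments
(kw_line,) = plt.plot(range(1, 8), marks, color='g', marker='^', linestyle='None')

for line in (fmt_line, kw_line):
    print(line.get_color(), line.get_marker(), line.get_linestyle())

plt.close('all')
```

The keyword form is more verbose but easier to read once you need more options than a format string can express.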
Working with Figures and Axes
Figure Object
Think of the Figure object as your frame: it is the top-level container within which one or more graphs can be plotted.
plt.figure() creates a new, empty Figure object. Its optional parameters include:
- figsize: figure dimensions (width, height) in inches
- dpi: resolution in dots per inch
- facecolor: background color of the figure
- edgecolor: color of the figure border (visible only when linewidth is greater than 0)
- linewidth: width of the figure border (0 by default)
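The parameters above combine in a simple way: the rendered size in pixels is figsize multiplied by dpi. The sketch below, assuming Matplotlib 3.x with the off-screen Agg backend, creates a 7 x 5 inch figure at 100 dpi and reads back its pixel size; the border width of 4 is an arbitrary illustrative value.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

# A 7 x 5 inch figure at 100 dots per inch
fig = plt.figure(figsize=(7, 5), dpi=100, facecolor='pink',
                 edgecolor='b', linewidth=4)  # linewidth > 0 makes the edge color visible

width_in, height_in = fig.get_size_inches()
width_px, height_px = fig.canvas.get_width_height()
print(f"{width_in} x {height_in} inches at {fig.dpi} dpi -> {width_px} x {height_px} pixels")

plt.close(fig)
```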
Axes Object
A Figure can contain one or more Axes (individual plots). An Axes object is the region on which you plot your data, and each Axes has its own title, x-label, and y-label.
- fig.add_axes() to add an Axes to the figure
- ax.set_title() to set the title
- ax.set_xlabel() and ax.set_ylabel() to set the x- and y-label, respectively
#creating fig
fig = plt.figure(figsize=[7, 5], facecolor='pink', edgecolor='b')
#adding axes to fig
ax = fig.add_axes([0,0,1,1])
ax.set_title("New Figure and Axes")
ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
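The list passed to add_axes() is [left, bottom, width, height], each a fraction of the figure size. [0, 0, 1, 1] fills the whole figure, so tick and axis labels may be cut off when the figure is saved or shown outside a notebook. A minimal sketch, assuming Matplotlib 3.x, that leaves a margin and places a second, smaller Axes inside the same Figure (the positions are arbitrary illustrative values):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(7, 5))

# [left, bottom, width, height] in figure fractions (0 to 1)
main_ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])     # 10% margin on each side for labels
inset_ax = fig.add_axes([0.6, 0.6, 0.25, 0.25])  # a small Axes drawn over the main one

main_ax.set_title("Main Axes")
inset_ax.set_title("Inset Axes")

print(len(fig.axes))                  # both Axes belong to the same Figure
print(main_ax.get_position().bounds)  # position stored in figure coordinates
plt.close(fig)
```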
Let’s add the line plot we created in the above example to our ‘New Figure and Axes’:
#Example
fig = plt.figure(figsize=[7, 5], facecolor='pink', edgecolor='b')
ax = fig.add_axes([0,0,1,1])
ax.set_title("Example Line Plot")
ax.set_xlabel('roll no.')
ax.set_ylabel('marks out of 50')
ax.plot([1,2,3,4,5,6,7],[32,26,43,41,16,37,29])
plt.show()
As discussed, let’s see how we can add multiple plots in a single figure:
Subplots
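We use pyplot.subplots() to create a figure and a grid of subplots with a single call. It returns a Figure object together with either a single Axes object or, when the grid has more than one subplot, a NumPy array of Axes objects. Below, we unpack the two Axes of a 1 x 2 grid directly.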
Let’s add another plot to the above example.
#Example
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12,5], facecolor='pink', edgecolor='b')
fig.suptitle('Example Subplots')
ax1.plot([1,2,3,4,5,6,7],[32,26,43,41,16,37,29])
ax1.set_title("Class 1")
ax1.set_xlabel('roll no.')
ax1.set_ylabel('marks out of 50')
ax2.plot([1,2,3,4,5,6,7],[48,32,40,44,36,27,21], color='g')
ax2.set_title("Class 2")
ax2.set_xlabel('roll no.')
ax2.set_ylabel('marks out of 50')
plt.show()
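The shape of what subplots() returns depends on the grid you ask for. This sketch, assuming Matplotlib 3.x and NumPy, compares three common cases and shows .flat as a shape-independent way to loop over every Axes; sharey=True is an optional argument that makes all subplots share one y-axis range.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np

fig1, ax = plt.subplots()                         # 1 x 1 grid: a single Axes object
fig2, axs_row = plt.subplots(1, 2)                # one row: 1-D array of 2 Axes
fig3, axs_grid = plt.subplots(2, 3, sharey=True)  # 2-D array; all subplots share the y-axis

print(type(ax).__name__, isinstance(axs_row, np.ndarray), axs_row.shape, axs_grid.shape)

# .flat iterates over every Axes, whatever the grid shape
for i, a in enumerate(axs_grid.flat):
    a.set_title(f"Plot {i + 1}")

plt.close('all')
```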
Important Matplotlib Plots in Machine Learning
Matplotlib provides a wide variety of plot types to support different visualization methods. The most popular ones are:
- Line Plots
- Pie Charts
- Histograms
- Scatter Plots
- Box Plots
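All five plot types above are single calls on an Axes object. A minimal sketch, assuming Matplotlib 3.x and NumPy, that draws one of each in a 2 x 3 grid; all the data is hypothetical and randomly generated for illustration, and the output file name plot_types.png is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; in a notebook, use plt.show() instead
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical synthetic data, for illustration only
rng = np.random.default_rng(seed=0)
roll_no = np.arange(1, 8)
marks = np.array([32, 26, 43, 41, 16, 37, 29])
scores = rng.normal(loc=30, scale=8, size=200)
hours = rng.uniform(0, 10, size=50)
result = 20 + 2.5 * hours + rng.normal(0, 3, size=50)

fig, axs = plt.subplots(2, 3, figsize=(12, 7))

axs[0, 0].plot(roll_no, marks)                       # trend over an ordered axis
axs[0, 0].set_title("Line plot")

axs[0, 1].pie([45, 35, 20], labels=["A", "B", "C"], autopct="%1.1f%%")  # parts of a whole
axs[0, 1].set_title("Pie chart")

counts, edges, _ = axs[0, 2].hist(scores, bins=20)   # distribution of one variable
axs[0, 2].set_title("Histogram")

axs[1, 0].scatter(hours, result, s=12)               # relationship between two variables
axs[1, 0].set_title("Scatter plot")

axs[1, 1].boxplot([scores, scores + 5])              # median, quartiles, outliers per group
axs[1, 1].set_title("Box plot")

axs[1, 2].axis("off")                                # the sixth grid cell is unused

fig.tight_layout()
fig.savefig("plot_types.png")
```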
Use of Matplotlib in Machine Learning
As discussed above, the Matplotlib library is used during the Exploratory Data Analysis (EDA) and Data Visualization phases of an ML model building process.
Let’s understand how EDA is done using Matplotlib with an example based on the Harmonized System (HS). Developed by the WCO (World Customs Organization), it is a multipurpose international product nomenclature that describes the types of commodities imported or exported each year. The system is used by more than 200 countries. It comprises about 5,000 commodity groups, each identified by a six-digit code (HS Code). We will use two datasets that contain records of India’s import and export products.
Step 1 – Import the required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
Step 2 – Load the datasets
#Read the data from the uploaded csv files
data_export = pd.read_csv('export.csv')
data_import = pd.read_csv('import.csv')
You can concatenate the two datasets as shown below:
#Concatenate the data
data_export['cat'] = 'E'
data_import['cat'] = 'I'
df = pd.concat([data_export,data_import],ignore_index=True)
df
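Tagging each frame with a 'cat' column before stacking is what lets later steps tell exports ('E') from imports ('I'). The sketch below uses tiny hypothetical frames with the column names the article’s code relies on (HSCode, Commodity, value, country, year); the rows are invented for illustration. It also shows why ignore_index=True matters: without it, the row labels from each input frame are repeated.

```python
import pandas as pd

# Hypothetical toy rows using the column names from the article's code
data_export = pd.DataFrame({
    'HSCode': [1, 2],
    'Commodity': ['COMMODITY A', 'COMMODITY B'],
    'value': [12.5, 3.0],
    'country': ['U K', 'CANADA'],
    'year': [2018, 2018],
})
data_import = pd.DataFrame({
    'HSCode': [1],
    'Commodity': ['COMMODITY A'],
    'value': [8.0],
    'country': ['CANADA'],
    'year': [2018],
})

# Record where each row came from before stacking the frames
data_export['cat'] = 'E'
data_import['cat'] = 'I'

no_reset = pd.concat([data_export, data_import])                 # index labels repeat: 0, 1, 0
df = pd.concat([data_export, data_import], ignore_index=True)    # renumbered: 0, 1, 2
print(list(no_reset.index), list(df.index))
print(df.groupby('cat')['value'].sum())
```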
Now, let’s proceed with analyzing this data.
Step 3 – Perform EDA
Which are the top ten countries by export value?
To answer this, find the top ten export destinations of India:
df1 = data_export.groupby('country').agg({'value':'sum'})
df1 = df1.sort_values(by='value', ascending = False)
df1 = df1[:10]
df1
Let’s plot a bar graph for the same. bar() draws vertical bars and barh() draws horizontal bars; here we use barh():
#Plotting a horizontal bar plot
fig = plt.figure(figsize = (10, 5))
plt.barh(df1.index, df1.value)
plt.xlabel("Value")
plt.ylabel("Country")
plt.title("Country-wise Export")
plt.show()
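The group-sort-slice steps can also be written with Series.nlargest(), and one detail of barh() is worth knowing: it draws the first row at the bottom, so the biggest bar ends up at the bottom unless you invert the y-axis. A minimal sketch on hypothetical data (countries A to E and their values are invented), keeping the top 3 instead of 10 so the toy example stays small:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical export records: several rows per country
data_export = pd.DataFrame({
    'country': ['A', 'B', 'A', 'C', 'D', 'B', 'E'],
    'value':   [10.0, 25.0, 5.0, 7.0, 30.0, 2.0, 1.0],
})

# Sum per country, then keep the N largest totals in descending order
top = data_export.groupby('country')['value'].sum().nlargest(3)
print(top)

fig, ax = plt.subplots(figsize=(10, 5))
ax.barh(top.index, top.values)
ax.invert_yaxis()  # first row is drawn at the bottom by default; flip so the largest is on top
ax.set_xlabel("Value")
ax.set_ylabel("Country")
ax.set_title("Country-wise Export (top 3)")
plt.close(fig)
```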
Find the trend in the trade deficit for India.
The trade deficit is the amount by which the cost of a country’s imports exceeds the value of its exports.
Let’s plot a line chart comparing India’s total import and export values for each year. The gap between the two lines gives a fair idea of the trade deficit:
#Plotting a simple line plot
fig = plt.figure(figsize = (10, 5))
df[df['cat'] == 'I'].groupby(['year'])['value'].sum().plot(c='orange', label='Imports');
df[df['cat'] == 'E'].groupby(['year'])['value'].sum().plot(c='purple', label='Exports');
plt.legend()
plt.show()
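The definition above translates directly into a formula: trade deficit = total imports − total exports, computed per year (a negative result is a surplus). A minimal sketch that computes and plots it, assuming the same 'year', 'cat', and 'value' columns as the article’s combined df; the rows here are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical combined data in the article's format ('E' = export, 'I' = import)
df = pd.DataFrame({
    'year':  [2019, 2019, 2019, 2020, 2020, 2021, 2021],
    'cat':   ['I', 'I', 'E', 'I', 'E', 'I', 'E'],
    'value': [50.0, 30.0, 60.0, 70.0, 55.0, 90.0, 65.0],
})

# One row per year, one column per category ('E', 'I'), cells hold total value
totals = df.groupby(['year', 'cat'])['value'].sum().unstack('cat')

# Trade deficit = imports - exports (negative values would mean a surplus)
totals['deficit'] = totals['I'] - totals['E']
print(totals)

ax = totals[['I', 'E', 'deficit']].plot(figsize=(10, 5), marker='o')
ax.set_xlabel('Year')
ax.set_ylabel('Value')
ax.legend(['Imports', 'Exports', 'Trade deficit'])
plt.close('all')
```

Plotting the deficit as its own line makes the trend explicit instead of leaving the reader to judge the gap between two curves.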
Visualize the max export values to the UK for each year.
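Calling .max() on a grouped DataFrame computes each column’s maximum independently, so the Commodity column would return the alphabetically last name (such as ‘ZINC AND ARTICLES THEREOF’) rather than the commodity behind the largest export value. To get the commodity that actually had the maximum export value to the UK in each year, select the matching rows with idxmax():
uk = df[(df['country'] == 'U K') & (df['cat'] == 'E')]
uk.loc[uk.groupby('year')['value'].idxmax(), ['year', 'value', 'Commodity']]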
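The difference between a grouped .max() and idxmax() is easy to see on a tiny hypothetical table (the commodity names and values below are invented). In 2019 the largest value belongs to COFFEE, yet the plain .max() pairs it with ZINC simply because ZINC sorts last alphabetically.

```python
import pandas as pd

# Hypothetical UK export rows: two commodities per year
uk = pd.DataFrame({
    'year':      [2019, 2019, 2020, 2020],
    'Commodity': ['COFFEE', 'ZINC', 'COFFEE', 'ZINC'],
    'value':     [100.0, 5.0, 80.0, 120.0],
})

# groupby().max() takes each column's maximum separately:
# 'value' is correct, but 'Commodity' is just the alphabetically last name
naive = uk.groupby('year').max()[['value', 'Commodity']]
print(naive)

# idxmax() returns the row label of the largest 'value' in each year,
# so .loc keeps the value and commodity from the same row
best = uk.loc[uk.groupby('year')['value'].idxmax(), ['year', 'value', 'Commodity']]
print(best)
```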
Let’s plot a bar graph of the maximum export value to the UK for each year:
#Plotting a vertical bar graph
fig = plt.figure(figsize = (10, 5))
ax = df[(df['country'] == 'U K') & (df['cat'] == 'E')].groupby(['year'])['value'].max().plot(kind="bar");
ax.set_xlabel('Year')
ax.set_ylabel('Maximum export value')
Compare the distribution of import/export values of expensive commodities across HS codes.
Let’s say that a trade is expensive if its value exceeds 1000. We’ll plot a box plot of these values for each HS code; each box shows the median, quartiles, and outliers of its group:
thresh = 1000
df2 = df[(df.value > thresh)]
#Plotting a box plot
df2.boxplot(column = 'value', by = 'HSCode', figsize=(12,7))
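A box plot summarizes each group by its median and quartiles, not its mean, and flags values beyond 1.5 times the interquartile range (IQR) as outliers. The sketch below, assuming NumPy and Matplotlib 3.x, computes these statistics by hand for one hypothetical group and compares them with matplotlib.cbook.boxplot_stats(), the helper that Axes.boxplot() uses to compute what it draws. The single extreme value pulls the mean well above the median.

```python
import numpy as np
from matplotlib import cbook

# Hypothetical trade values for one group, with one extreme value
values = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 100], dtype=float)

# Quartiles and the 1.5 * IQR whisker rule, computed by hand
q1, med, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1
upper_limit = q3 + 1.5 * iqr
lower_limit = q1 - 1.5 * iqr
whisker_hi = values[values <= upper_limit].max()  # furthest point inside the upper limit
whisker_lo = values[values >= lower_limit].min()  # furthest point inside the lower limit
outliers = values[(values > whisker_hi) | (values < whisker_lo)]

# Matplotlib's own summary for the same data
stats = cbook.boxplot_stats(values)[0]
print("median:", med, "mean:", values.mean())
print("whiskers:", whisker_lo, whisker_hi, "outliers:", outliers)
print("matplotlib:", stats['med'], stats['whislo'], stats['whishi'], stats['fliers'])
```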
Analyze the import values of commodities imported from Canada.
Plot the total values of imports each year using a pie chart:
x = df[(df.country == 'CANADA') & (df.cat == 'I')].groupby('year')['value'].sum().reset_index()
#Plotting a pie chart
plt.pie(x['value'], labels=x['year']);
plt.axis('equal')
plt.tight_layout()
plt.show()
We can make the pie chart more readable by highlighting the maximum import value:
#Explode the slice with the highest imports
fig = plt.figure(figsize = (10, 5))
max_value = x['value'].max()
explode = [0.15 if v == max_value else 0 for v in x['value']]
plt.pie(x['value'], labels=x['year'], explode=explode, startangle=90, autopct='%1.1f%%');
plt.axis('equal')
plt.title('Imports from Canada')
plt.tight_layout()
plt.show()
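The autopct argument is an old-style % format string that Matplotlib applies to each slice’s percentage of the total: '%1.1f' prints one decimal place and '%%' prints a literal percent sign. This sketch, assuming Matplotlib 3.x, rebuilds the labels by hand for three hypothetical yearly totals and compares them with the text Matplotlib actually places on the slices.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

# Hypothetical yearly import totals
values = [120.0, 300.0, 80.0]
labels = ['2019', '2020', '2021']

# Each slice's share of the total, formatted the way autopct='%1.1f%%' does
manual = ['%1.1f%%' % (100 * v / sum(values)) for v in values]

fig, ax = plt.subplots()
wedges, texts, autotexts = ax.pie(values, labels=labels, autopct='%1.1f%%')
shown = [t.get_text() for t in autotexts]
print(manual, shown)
plt.close(fig)
```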
So, do you see how easy Matplotlib makes it to visualize and analyze data? Once you have found relevant patterns in your data, you can go ahead with model development using Machine Learning algorithms.
Endnotes
Matplotlib is one of the oldest Python data visualization libraries, and thanks to its wealth of features and ease of use it is still one of the most widely used ones. Matplotlib was first released back in 2003 and has been continuously updated since. I hope this article helped you understand the concept of Matplotlib in Machine Learning.
This is a collection of insightful articles from domain experts in the fields of Cloud Computing, DevOps, AWS, Data Science, Machine Learning, AI, and Natural Language Processing. The range of topics caters to upski...