Tutorial – Matplotlib Scatter Plot – Shiksha Online

Updated on Mar 4, 2022 16:14 IST

If you’re delving into machine learning and data science, you’re sure to perform Exploratory Data Analysis (EDA) to analyze the data before getting on with model development. EDA helps summarize the main characteristics of your data, mostly employing data visualization methods.

Matplotlib, Python's most popular visualization library, supports many useful kinds of graphical visualization. In this article, we focus on scatter plots, a common technique for observing relationships between variables in your data. Let's see how to perform EDA with scatter plots in Python using Matplotlib.

Quick Intro to Scatter Plots

A scatter plot is used to visualize the relationship between two numerical variables. The values of the variables are represented by dots using Cartesian coordinates. The positioning of the dots, called markers, allows us to infer if there is a correlation between the variables and the strength of this correlation. 

What is meant by correlation?

  • Correlation is a dimensionless measure of the degree to which two variables are related.
  • It measures both the strength and the direction of the linear relationship between variables.
  • Its value lies between -1 and 1. The absolute value, from 0 to 1, indicates strength.
  • The + and – signs indicate direction.
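A few lines of code can confirm these properties. The minimal sketch below assumes the Pearson correlation coefficient, which is the most common measure and the one that matches the description above. It divides the covariance of the two variables by the product of their standard deviations. That division is what makes the result dimensionless and keeps it between -1 and 1. The helper name pearson_r and the sample values are made up for illustration. NumPy's np.corrcoef() serves as a cross-check.

```python
import numpy as np

def pearson_r(x, y):
    """Pearson correlation: covariance of x and y divided by the product of their standard deviations."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()  # deviations of x from its mean
    dy = y - y.mean()  # deviations of y from its mean
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())

# Made-up sample values, for illustration only
x = [1, 2, 3, 4, 5]
y_up = [2, 4, 6, 8, 10]    # y rises in exact proportion to x
y_down = [10, 8, 6, 4, 2]  # y falls in exact proportion as x rises

print(pearson_r(x, y_up))    # 1.0  (perfect positive)
print(pearson_r(x, y_down))  # -1.0 (perfect negative)
# NumPy's built-in should agree up to floating-point rounding
print(np.corrcoef(x, y_up)[0, 1])
```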

What is meant by the strength of a correlation?

  1. Perfect Linear Correlation:

A perfect positive correlation is given the value of 1, and a perfect negative correlation is given the value of -1.

  2. Strong Linear Correlation:

The closer the number is to 1 or -1, the stronger the correlation between the two variables.

  3. Weak Linear Correlation:

The closer the number is to 0, the weaker the correlation.

  4. No Correlation:

If there is absolutely no correlation between the two variables, the value given is 0.
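Synthetic data can reproduce the four cases above. The sketch below uses made-up random values, not real measurements. It keeps x fixed, varies how much noise is added to y, and prints the correlation coefficient for each case. A strong negative case is included to show direction. The noise levels are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(seed=0)  # fixed seed so the numbers are reproducible
n = 500
x = rng.normal(size=n)
noise = rng.normal(size=n)

# Hypothetical datasets: same x, different amounts of noise in y
cases = {
    "perfect": 2 * x + 1,              # exact straight-line relationship
    "strong": 2 * x + 0.5 * noise,     # small amount of noise
    "weak": 0.3 * x + noise,           # noise dominates the trend
    "none": noise,                     # y generated independently of x
    "negative": -2 * x + 0.5 * noise,  # strong, but y falls as x rises
}

results = {name: np.corrcoef(x, y)[0, 1] for name, y in cases.items()}
for name, r in results.items():
    print(f"{name:>8}: r = {r:+.2f}")
```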

Now, let's see how to create a scatter plot to examine the correlation between two variables. The dataset used in this blog can be found here. It contains information on customers' BMI, sex, number of children, whether or not they smoke, and their insurance charges. We want to find out whether there is a relationship between a customer's BMI and their insurance charges.

Let’s get started!

Installing and Importing Matplotlib

First, let’s install the Matplotlib library in your working environment. Execute the following command in your terminal:

 
pip install matplotlib

Now let’s import the libraries we’re going to need today:

 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Jupyter-only magic command: render plots below each cell
%matplotlib inline

In Matplotlib, pyplot is used to create figures and change their characteristics.

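The %matplotlib inline magic command makes plots appear directly below the code cell in Jupyter Notebook. It only works in IPython/Jupyter. In a plain Python script, leave it out and call plt.show() to display a plot.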

Creating a Matplotlib Scatter Plot

Load the dataset

Prior to creating our graphs, let’s check out the dataset:

 
#Read the dataset
df = pd.read_csv('insurance.csv')
df.head()
 
# Check the number of rows and columns
df.shape

There are 7 columns (or features) in this dataset. Let’s print them all:

 
#List all column names
print(df.columns)

To answer our question, we will focus on the bmi and charges columns of the dataset.
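If you do not have insurance.csv at hand, the sketch below builds a small hypothetical stand-in DataFrame that uses the bmi, smoker and charges column names. It has only three of the real file's seven columns, and its values are randomly generated, not taken from the real dataset. The sketch then checks that the two columns we need are numeric and complete before plotting.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for insurance.csv: same column names, made-up values
rng = np.random.default_rng(seed=42)
n = 200
bmi = rng.normal(loc=30, scale=6, size=n).round(1)
smoker = rng.choice(["yes", "no"], size=n, p=[0.2, 0.8])
charges = 250 * bmi + np.where(smoker == "yes", 20000, 0) + rng.normal(scale=4000, size=n)
df = pd.DataFrame({"bmi": bmi, "smoker": smoker, "charges": charges.round(2)})

# Keep only the two columns the analysis needs and check them before plotting
subset = df[["bmi", "charges"]]
print(subset.dtypes)        # both should be numeric for plt.scatter()
print(subset.isna().sum())  # number of missing values per column
print(subset.describe())    # value ranges help sanity-check the axis scales
```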

Plotting the data

Now, let's create a scatter plot using the plt.scatter() function:

 
plt.scatter(x=df['bmi'], y=df['charges'])

The plot above gives a general idea that the two variables are related. Without a title or axis labels, though, we cannot extract much useful information from it yet.

Let’s add a few elements here to help us interpret the visualization in a better way. 

Adding Elements to the Scatter Plot

Without context, the plot we have created would be hard for someone else to understand, so let's add a few elements to make it more readable:

  • Use plt.title() for setting a plot title
  • Use plt.xlabel() and plt.ylabel() for labeling x and y-axis respectively
  • Use plt.show() for displaying the plot
 
x = df['bmi']
y = df['charges']
#Add elements to scatterplot
plt.scatter(x, y)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()

From the plot above, we can see that BMI and insurance charges are positively correlated: customers with a higher body mass index also tend to incur higher insurance costs. This pattern makes sense, given that a higher BMI is typically associated with a higher risk of chronic ailments.
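pandas' Series.corr() puts a number on what the plot shows by computing the Pearson correlation coefficient between two columns. On the real data you would call df['bmi'].corr(df['charges']). The sketch below runs it on made-up data instead, so the value it prints says nothing about the actual insurance dataset.

```python
import numpy as np
import pandas as pd

# Made-up data standing in for the bmi and charges columns (not the real dataset)
rng = np.random.default_rng(seed=1)
bmi = rng.normal(loc=30, scale=6, size=300)
charges = 250 * bmi + rng.normal(scale=4000, size=300)
df = pd.DataFrame({"bmi": bmi, "charges": charges})

# Pearson correlation between the two columns (pandas' default method)
r = df["bmi"].corr(df["charges"])
print(f"Pearson r = {r:.2f}")

# The same value appears off the diagonal of the full correlation matrix
print(df.corr())
```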

Adding a Regression Line to the Scatter Plot

The scatter plot is a close cousin of the line plot: instead of being joined by a line, the points are shown as scattered dots. As with a line plot, we can draw a regression line through our scatter plot. This is the straight line that best fits the data, and it makes the trend of this relationship easier to judge.

Can you recall the linear equation formula from your elementary maths class? Let’s jog your memory:

y = mx + b, where m is the slope of the line and b is its y-intercept.

We will use NumPy's np.polyfit() function to fit this line (a polynomial of degree 1), then draw it over the scatter plot with plt.plot():

 
import numpy as np
plt.scatter(x, y)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
# Fit a degree-1 polynomial (a straight line): returns slope m and intercept b
m, b = np.polyfit(x, y, 1)
plt.plot(x, m*x+b)
plt.show()

The upward-sloping line shows the trend of charges with BMI: charges tend to rise as BMI increases. The wide spread of points around the line, however, suggests that this positive correlation is weak.
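Under the hood, a degree-1 np.polyfit() performs an ordinary least-squares fit. The minimal sketch below computes the closed-form slope and intercept and checks them against np.polyfit() on a handful of made-up points. The helper name fit_line is illustrative. The sketch also shows how the slope relates to the correlation coefficient: r equals the slope times the ratio of the standard deviations of x and y. That is why the steepness of the line alone does not tell you how strong the correlation is.

```python
import numpy as np

def fit_line(x, y):
    """Ordinary least-squares line y = m*x + b using the closed-form formulas."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()
    m = (dx * (y - y.mean())).sum() / (dx ** 2).sum()  # slope
    b = y.mean() - m * x.mean()                          # intercept
    return m, b

# Made-up points, for illustration only
x = np.array([18.0, 22.5, 25.0, 28.3, 31.0, 35.2, 40.1])
y = np.array([3000.0, 4200.0, 3900.0, 6100.0, 5200.0, 7800.0, 8100.0])

m, b = fit_line(x, y)
m_np, b_np = np.polyfit(x, y, 1)  # degree 1 returns [slope, intercept]
print(m, b)
print(m_np, b_np)  # should agree up to floating-point rounding

# Slope rescaled by the standard deviations gives Pearson's r
r_from_slope = m * x.std() / y.std()
print(r_from_slope, np.corrcoef(x, y)[0, 1])
```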

You can also change the color of the regression line:

 
plt.scatter(x, y)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
m, b = np.polyfit(x, y, 1)
plt.plot(x, m*x+b, c='r')
plt.show()

Parameters of the Matplotlib Scatter Plot

  • Use parameter c to specify the marker colors:
 
plt.scatter(x, y, c='y')
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()

  • Use the marker parameter to specify the marker style, for example:

 ('o', 'v', '^', '<', '>', '8', 's', 'p', '*', 'h', 'H', 'D', 'd', 'P', 'X')

 
plt.scatter(x, y, marker='v')
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
  • Use the edgecolors parameter to draw the marker edges in the specified color:
 
plt.scatter(x, y, marker='v', edgecolors='r')
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
  • Use the alpha parameter to set the transparency of the markers. It takes a float between 0 (fully transparent) and 1 (fully opaque):
 
plt.scatter(x, y, marker='v', alpha=0.3)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
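You can also combine the parameters above in a single call. The sketch below does so on made-up data, with random BMI and charge values rather than values from insurance.csv. It then inspects the PathCollection object that plt.scatter() returns to confirm the settings were applied.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data standing in for the bmi and charges columns
rng = np.random.default_rng(seed=3)
x = rng.normal(loc=30, scale=6, size=200)
y = 250 * x + rng.normal(scale=4000, size=200)

# c, marker, edgecolors and alpha can all be passed in one call
points = plt.scatter(x, y, c="y", marker="v", edgecolors="r", alpha=0.3)
plt.ylabel("Charges")
plt.xlabel("BMI")
plt.title("BMI vs Insurance Charges")
# plt.show()  # uncomment when running as a script

# plt.scatter() returns a PathCollection whose properties can be inspected
print(points.get_alpha())         # 0.3
print(len(points.get_offsets()))  # 200 plotted points
```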

Adding a Category to the Scatter Plot

You can obtain even more information from a scatter plot by analyzing the relationship of the two numerical variables with a categorical variable.

For example, suppose you want to find out whether smoking affects the correlation between BMI and insurance charges. We will use the parameter c to color-code the data points. We define a dictionary that maps each value of the smoker column to a color, then apply it with map(), as shown below:

 
colors = {'yes':'red', 'no':'green'}
plt.scatter(x, y, c=df['smoker'].map(colors))
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
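One limitation of the map() approach is that the resulting plot has no legend, because the scatter call receives only a list of colors. The minimal sketch below shows one way to add a legend, using made-up stand-in data: it builds the legend entries by hand with matplotlib.patches.Patch. It also checks that every smoker value has a color, because values missing from the dictionary would map to NaN.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Hypothetical stand-in data with the same column names as insurance.csv
rng = np.random.default_rng(seed=7)
n = 200
df = pd.DataFrame({
    "bmi": rng.normal(loc=30, scale=6, size=n),
    "smoker": rng.choice(["yes", "no"], size=n, p=[0.2, 0.8]),
})
df["charges"] = 250 * df["bmi"] + np.where(df["smoker"] == "yes", 20000, 0) + rng.normal(scale=4000, size=n)

colors = {"yes": "red", "no": "green"}
point_colors = df["smoker"].map(colors)
# Values missing from the dictionary would become NaN, so check coverage first
assert point_colors.notna().all(), "every smoker value needs a color"

plt.scatter(df["bmi"], df["charges"], c=point_colors)
# A color-mapped scatter has no labels of its own, so build legend entries by hand
handles = [Patch(color=color, label=label) for label, color in colors.items()]
legend = plt.legend(handles=handles, title="Smoker")
plt.ylabel("Charges")
plt.xlabel("BMI")
plt.title("BMI vs Insurance Charges")
# plt.show()  # uncomment when running as a script
```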

Color by category using Pandas GroupBy

We can also create a plot by grouping our DataFrame based on the smoker column and manually assigning colors based on the groups.

 
fig, ax = plt.subplots()
colors = {'yes':'red', 'no':'green'}
grouped = df.groupby('smoker')
# Draw each group on the same axes with its own color and legend label
for key, group in grouped:
    group.plot(ax=ax, kind='scatter', x='bmi', y='charges', label=key, color=colors[key])
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges for Smokers/Non-smokers')
ax.grid(True)
plt.show()

Notice that we also added a grid to the plot above by calling ax.grid(True).

From the given scatter plot, we can conclude that customers who smoke pay higher insurance charges as compared to non-smokers for the same BMI.
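One way to check this claim numerically, not used in the original walkthrough, is to fit a separate regression line for each group with np.polyfit() and compare the fitted charges at the same BMI. The data below are randomly generated with a made-up smoker surcharge, so the numbers only illustrate the method.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in data: smokers get a made-up fixed surcharge
rng = np.random.default_rng(seed=11)
n = 400
df = pd.DataFrame({
    "bmi": rng.normal(loc=30, scale=6, size=n),
    "smoker": rng.choice(["yes", "no"], size=n, p=[0.25, 0.75]),
})
df["charges"] = 250 * df["bmi"] + np.where(df["smoker"] == "yes", 20000, 0) + rng.normal(scale=4000, size=n)

# Fit a separate straight line for each smoker group
lines = {}
for key, group in df.groupby("smoker"):
    m, b = np.polyfit(group["bmi"], group["charges"], 1)
    lines[key] = (m, b)

# Compare the fitted charges of both groups at the same BMI
bmi_value = 30
for key, (m, b) in lines.items():
    print(f"smoker={key}: predicted charges at BMI {bmi_value} = {m * bmi_value + b:.0f}")
```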

Saving your Matplotlib Scatter Plot

You can save your plot as an image using the savefig() function. Plots can be saved as .png, .jpeg, .pdf, and many other supported formats, and the format is inferred from the file extension.

Let's save the BMI vs Insurance Charges plot, color-coded by smoker status, that we created above:

 
colors = {'yes':'red', 'no':'green'}
plt.scatter(x, y, c=df['smoker'].map(colors))
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
# Save the current figure before calling plt.show(); saving afterwards can produce a blank image
plt.savefig('scatterplot.png')

The plot is saved in the current working directory as 'scatterplot.png'.

To view the saved image, we’ll use the matplotlib.image module, as shown below:

 
#Displaying the saved image
import matplotlib.image as mpimg
image = mpimg.imread("scatterplot.png")
plt.imshow(image)
plt.show()
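plt.savefig() saves the current figure. When you hold a figure object, as in the GroupBy example, fig.savefig() saves that specific figure instead. The sketch below runs on made-up data and also uses two optional savefig() parameters. dpi sets the resolution of raster formats such as PNG, and bbox_inches='tight' trims surplus whitespace around the plot. Writing to a temporary folder is just a choice for the example.

```python
import os
import tempfile
import numpy as np
import matplotlib.pyplot as plt

# Made-up data; any figure is saved the same way
rng = np.random.default_rng(seed=5)
x = rng.normal(loc=30, scale=6, size=100)
y = 250 * x + rng.normal(scale=4000, size=100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_xlabel("BMI")
ax.set_ylabel("Charges")
ax.set_title("BMI vs Insurance Charges")

out_dir = tempfile.mkdtemp()  # a temporary folder instead of the project directory
saved = []
for ext in ("png", "pdf"):
    # The file extension selects the output format
    path = os.path.join(out_dir, f"scatterplot.{ext}")
    # dpi: resolution for raster output; bbox_inches="tight": trim extra whitespace
    fig.savefig(path, dpi=150, bbox_inches="tight")
    saved.append(path)

plt.close(fig)  # release the figure once it has been saved
print(saved)
```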

Limitations of using Matplotlib Scatter Plot

  • Overplotting

Overplotting occurs when there are so many data points that they overlap to the extent that it becomes difficult to see the relationship between the variables.

One way to alleviate this issue is to plot a random subset of the data points, which still gives a general idea of the patterns in the entire dataset. Other options are:

  • Adding transparency to the markers so that overlaps remain visible
  • Reducing the marker size so that fewer overlaps occur

For very large datasets, however, a different chart type, such as a heat map of point density, is often the better choice.
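The remedies above can be compared side by side. The minimal sketch below uses a large, randomly generated dataset, not the insurance data, and draws three panels. The first shows a random subset. The second uses small, semi-transparent markers set through the s and alpha parameters. The third is a heat map of point counts built with np.histogram2d() and drawn with pcolormesh(). The 50 bins and the alpha of 0.05 are arbitrary choices.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up, deliberately large dataset to provoke overplotting
rng = np.random.default_rng(seed=9)
n = 50_000
x = rng.normal(loc=30, scale=6, size=n)
y = 250 * x + rng.normal(scale=4000, size=n)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1) Plot a random subset of the points
idx = rng.choice(n, size=2_000, replace=False)
axes[0].scatter(x[idx], y[idx])
axes[0].set_title("Random subset")

# 2) Small, semi-transparent markers: dense regions appear darker
axes[1].scatter(x, y, s=2, alpha=0.05)
axes[1].set_title("Small markers, alpha=0.05")

# 3) Heat map of point counts per 2D bin
counts, xedges, yedges = np.histogram2d(x, y, bins=50)
mesh = axes[2].pcolormesh(xedges, yedges, counts.T)  # transpose: histogram2d puts x along rows
fig.colorbar(mesh, ax=axes[2], label="Number of points")
axes[2].set_title("2D histogram heat map")

for ax in axes:
    ax.set_xlabel("BMI")
    ax.set_ylabel("Charges")
# plt.show()  # uncomment when running as a script
```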

  • Wrongly interpreting correlation and causation

Causation means that changes in one variable cause changes in the other. For example, eating a lot of junk food contributes to obesity. A causal relationship like this usually shows up as a correlation, but the reverse does not hold.

There is a common phrase in statistics: 'Correlation does not imply causation.' The observed relationship may be caused by a third variable that affects both plotted variables, even though they do not directly affect each other at all. So, if a direct causal link cannot be established, further analysis is required to account for the effects of other potential variables.
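A small simulation can illustrate the third-variable problem. In this hypothetical setup, a hidden variable z drives both x and y, which never influence each other. The two still appear strongly correlated. The correlation disappears only once the effect of z is removed. This is done by correlating the residuals of x and y after a straight-line fit on z, which is a simple form of partial correlation.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n = 10_000

# Hypothetical simulation: z drives both x and y; x and y never affect each other
z = rng.normal(size=n)
x = z + 0.5 * rng.normal(size=n)
y = z + 0.5 * rng.normal(size=n)

print(f"corr(x, y) = {np.corrcoef(x, y)[0, 1]:.2f}")  # strong, despite no direct link

def residuals(v, z):
    """Part of v not explained by a straight-line fit on z."""
    m, b = np.polyfit(z, v, 1)
    return v - (m * z + b)

rx, ry = residuals(x, z), residuals(y, z)
print(f"corr after accounting for z = {np.corrcoef(rx, ry)[0, 1]:.2f}")  # close to 0
```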

Endnotes

A scatter plot is one of the most commonly used chart types for visualizing data. It is also known as a scattergram, scatter graph, or scatter chart. Matplotlib is one of the oldest Python visualization libraries and provides a wide variety of charts and plots for better analysis, including the scatter plot.


About the Author

This is a collection of insightful articles from domain experts in the fields of Cloud Computing, DevOps, AWS, Data Science, Machine Learning, AI, and Natural Language Processing. The range of topics caters to upski...