Tutorial – Matplotlib Scatter Plot – Shiksha Online
If you’re delving into machine learning and data science, you’re sure to perform Exploratory Data Analysis (EDA) to analyze the data before getting on with model development. EDA helps summarize the main characteristics of your data, mostly employing data visualization methods.
Matplotlib, Python's most popular visualization library, supports many useful types of graphical visualization. In this article, we focus on scatter plots, a common technique for observing relationships between variables in your data. Let's see how to perform EDA with scatter plots in Python using Matplotlib.
We will be covering the following sections:
- Quick Intro to Scatter Plots
- Installing and Importing Matplotlib
- Creating a Matplotlib Scatter Plot
- Adding Elements to the Scatter Plot
- Adding a Regression Line to the Scatter Plot
- Parameters of the Scatter Plot
- Adding a Category to the Scatter Plot
- Saving your Scatter Plot
- Limitations of using Scatter Plots
Quick Intro to Scatter Plots
A scatter plot is used to visualize the relationship between two numerical variables. Each observation is drawn as a dot at its (x, y) position in Cartesian coordinates. The positions of these dots, called markers, allow us to infer whether the variables are correlated and how strong that correlation is.
What is meant by correlation?
- The correlation coefficient is a dimensionless measure of the degree to which two variables are related.
- It measures both the strength and the direction of the linear relationship between the variables.
- Its value lies between -1 and 1; its absolute value, between 0 and 1, depicts strength.
- The + and - signs depict direction.
What is meant by the strength of a correlation?
- Perfect Linear Correlation:
A perfect positive correlation is given the value of 1, and a perfect negative correlation is given the value of -1.
- Strong Linear Correlation:
The closer the number is to 1 or -1, the stronger the correlation between the two variables.
- Weak Linear Correlation:
The closer the number is to 0, the weaker the correlation.
- No Linear Correlation:
If there is no linear relationship between the two variables, the value is 0. A non-linear relationship may still exist in that case.
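The coefficient described above is usually Pearson's r: the covariance of the two variables divided by the product of their standard deviations. A minimal sketch, using small made-up arrays (not data from this tutorial), computes it from that definition and checks it against NumPy's np.corrcoef():

```python
import numpy as np

def pearson_r(x, y):
    """Pearson correlation coefficient computed from its definition."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Center each variable on its mean
    dx = x - x.mean()
    dy = y - y.mean()
    # Covariance divided by the product of the standard deviations;
    # the 1/n factors cancel, so plain sums are enough
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())

# Made-up example values, for illustration only
x = [1, 2, 3, 4, 5]
y_up = [2, 4, 6, 8, 10]     # perfect positive linear relationship
y_down = [10, 8, 6, 4, 2]   # perfect negative linear relationship
y_noisy = [2, 1, 4, 3, 5]   # positive but imperfect relationship

print(pearson_r(x, y_up))     # 1.0
print(pearson_r(x, y_down))   # -1.0
print(pearson_r(x, y_noisy))  # 0.8

# NumPy's built-in version agrees with the hand-written one
print(np.isclose(pearson_r(x, y_noisy), np.corrcoef(x, y_noisy)[0, 1]))  # True
```

The sign of the result gives the direction and its magnitude gives the strength, matching the scale described above.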
Now, let's create a scatter plot to determine the correlation between two variables. The dataset used in this blog can be found here. It contains information on customers' BMI, sex, and number of children, their insurance charges, and whether or not each individual is a smoker. We want to find out whether there is a relationship between a customer's BMI and their insurance charges.
Let’s get started!
Installing and Importing Matplotlib
First, install the Matplotlib library in your working environment. This tutorial also uses pandas and NumPy, so execute the following command in your terminal:
pip install matplotlib pandas numpy
Now let’s import the libraries we’re going to need today:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In Matplotlib, pyplot is used to create figures and change their characteristics.
The %matplotlib inline magic command makes plots render directly below the code cell in Jupyter Notebook. It only works in IPython/Jupyter, so leave it out of a regular Python script.
Creating a Matplotlib Scatter Plot
Load the dataset
Prior to creating our graphs, let’s check out the dataset:
# Read the dataset
df = pd.read_csv('insurance.csv')
df.head()

# Check the number of rows and columns
df.shape
There are 7 columns (or features) in this dataset. Let’s print them all:
# List all column names
print(df.columns)
Since our question concerns BMI and insurance costs, we will focus on the bmi and charges columns of the dataset.
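Before plotting, it can be worth confirming that the two columns are numeric and free of missing values, since rows with missing values cannot be plotted or used in a regression fit. A minimal sketch with a tiny hypothetical DataFrame that only mimics the bmi and charges column names (the values are made up, not taken from insurance.csv):

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for insurance.csv: made-up values, same column names
df = pd.DataFrame({
    "bmi": [25.0, 31.5, np.nan, 22.0, 28.0],
    "charges": [4000.0, 12000.0, 9000.0, 3000.0, 7000.0],
})

subset = df[["bmi", "charges"]]
print(subset.dtypes)        # both columns should be numeric (float64 here)
print(subset.isna().sum())  # number of missing values per column
print(subset.describe())    # count, mean, std, min, quartiles, max

# Drop incomplete rows before plotting or fitting a regression line
clean = subset.dropna()
print(clean.shape)          # (4, 2)
```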
Plotting the data
Now, let’s plot a scatter plot using the plt.scatter() function:
plt.scatter(x=df['bmi'], y=df['charges'])
The plot gives a general idea that the two variables are related, but without a title or axis labels we cannot extract much more information from it yet.
Let's add a few elements to make the visualization easier to interpret.
Adding Elements to the Scatter Plot
Without context, the plot we have created would be hard for anyone else to understand, so let's add a few elements to make it more readable:
- Use plt.title() for setting a plot title
- Use plt.xlabel() and plt.ylabel() for labeling x and y-axis respectively
- Use plt.show() for displaying the plot
x = df['bmi']
y = df['charges']

# Add elements to the scatter plot
plt.scatter(x, y)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
The plot suggests that BMI and insurance charges are positively correlated: customers with a higher body mass index tend to incur higher insurance costs, although the points are widely spread. This pattern makes sense, given that a higher BMI is typically associated with a higher risk of chronic ailments.
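Matplotlib also offers an object-oriented interface, which this tutorial uses later with plt.subplots() and ax.grid(). Here, plt.xlabel(), plt.ylabel() and plt.title() map to ax.set_xlabel(), ax.set_ylabel() and ax.set_title(). A sketch of the same labeled plot in that style, using synthetic, randomly generated stand-in data (hypothetical) in place of the bmi and charges columns:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for df['bmi'] and df['charges'] (randomly generated)
rng = np.random.default_rng(0)
x = rng.normal(30, 6, size=200)
y = 250 * x + rng.normal(0, 3000, size=200)

# Create the figure and a single set of axes explicitly
fig, ax = plt.subplots()
ax.scatter(x, y)

# Axes methods equivalent to plt.xlabel(), plt.ylabel() and plt.title()
ax.set_xlabel('BMI')
ax.set_ylabel('Charges')
ax.set_title('BMI vs Insurance Charges')
plt.show()
```

Keeping explicit fig and ax references makes it unambiguous which figure later calls such as fig.savefig() act on.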
Adding a Regression Line to the Scatter Plot
The scatter plot is a close cousin of the line plot: instead of being joined by a line, the points are drawn as scattered dots. We can overlay a regression line, the straight line that best fits the data, to make the trend in this relationship easier to see.
Can you recall the equation of a straight line from your elementary maths class? Let's jog your memory: y = mx + b, where m is the slope of the line and b is its y-intercept.
We will use np.polyfit() with degree 1 to compute the least-squares values of m and b, and then draw the line along the scatter plot with plt.plot():
plt.scatter(x, y)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')

m, b = np.polyfit(x, y, 1)
plt.plot(x, m*x + b)
plt.show()
The regression line shows how charges trend with the customers' BMI. Its upward slope confirms a positive relationship, while the wide spread of points around the line suggests that the correlation between the two variables is weak.
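For a degree-1 fit, np.polyfit() returns the ordinary least-squares slope and intercept, which have a closed form: m = cov(x, y) / var(x) and b = mean(y) - m * mean(x). A sketch that computes them directly on synthetic, randomly generated data (not the insurance dataset) and compares them with np.polyfit():

```python
import numpy as np

# Synthetic data scattered around the line y = 3x + 5 (randomly generated)
rng = np.random.default_rng(42)
x = rng.uniform(15, 50, size=300)
y = 3 * x + 5 + rng.normal(0, 10, size=300)

# Closed-form ordinary least-squares estimates
x_mean, y_mean = x.mean(), y.mean()
m_manual = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
b_manual = y_mean - m_manual * x_mean

# np.polyfit with degree 1 returns [slope, intercept]
m_fit, b_fit = np.polyfit(x, y, 1)

print(m_manual, b_manual)
print(m_fit, b_fit)
print(np.allclose([m_manual, b_manual], [m_fit, b_fit]))  # True
```

Both estimates should be close to the slope 3 and intercept 5 used to generate the data, differing only because of the random noise.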
You can also change the color of the regression line:
plt.scatter(x, y)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')

m, b = np.polyfit(x, y, 1)
plt.plot(x, m*x + b, c='r')
plt.show()
Parameters of the Matplotlib Scatter Plot
- Use parameter c to specify the marker colors:
plt.scatter(x, y, c='y')
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
- Use parameter marker to specify marker types and style:
('o', 'v', '^', '<', '>', '8', 's', 'p', '*', 'h', 'H', 'D', 'd', 'P', 'X')
plt.scatter(x, y, marker='v')
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
- Use the edgecolors parameter to outline the marker edges in the specified color:
plt.scatter(x, y, marker='v', edgecolors='r')
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
- Use the alpha parameter to specify the transparency of markers. It takes a float between 0 (fully transparent) and 1 (fully opaque):
plt.scatter(x, y, marker='v', alpha=0.3)
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
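These parameters can be combined in a single call. The sketch below also uses s, the scatter() parameter for marker area in points squared, which becomes relevant when dealing with overplotting. The data are synthetic and randomly generated, and the specific parameter values are arbitrary choices:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for the bmi and charges columns (randomly generated)
rng = np.random.default_rng(1)
x = rng.normal(30, 6, size=300)
y = 250 * x + rng.normal(0, 3000, size=300)

points = plt.scatter(
    x, y,
    c='y',           # marker face color
    marker='D',      # diamond markers
    edgecolors='k',  # black marker outlines
    alpha=0.5,       # 0 = fully transparent, 1 = fully opaque
    s=20,            # marker area in points**2
)
plt.xlabel('BMI')
plt.ylabel('Charges')
plt.title('BMI vs Insurance Charges')
plt.show()
```

plt.scatter() returns a PathCollection, so the settings can be inspected afterwards, for example with points.get_alpha() or points.get_sizes().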
Adding a Category to the Scatter Plot
You can obtain even more information from a scatter plot by analyzing the relationship of the two numerical variables with a categorical variable.
For example, let's say you want to find out whether smoking affects the relationship between BMI and insurance charges. We will use the parameter c, which allows us to color-code the data points: we define a dictionary that assigns a color to each value of the smoker column and apply it with map(), as shown below:
colors = {'yes':'red', 'no':'green'}
plt.scatter(x, y, c=df['smoker'].map(colors))
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.show()
Color by category using Pandas GroupBy
We can also create a plot by grouping our DataFrame based on the smoker column and manually assigning colors based on the groups.
fig, ax = plt.subplots()
colors = {'yes':'red', 'no':'green'}

grouped = df.groupby('smoker')
for key, group in grouped:
    group.plot(ax=ax, kind='scatter', x='bmi', y='charges', label=key, color=colors[key])

plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges for Smokers/Non-smokers')
ax.grid(True)
plt.show()
Notice that we added a grid to this plot by calling ax.grid(True).
From this scatter plot, we can see that at a given BMI, smokers tend to pay higher insurance charges than non-smokers.
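Coloring points with c=df['smoker'].map(colors) does not add a legend on its own. One alternative, sketched here with Matplotlib alone rather than pandas plotting, loops over the categories, draws each subset with its own label, and lets ax.legend() build the key. The DataFrame is hypothetical and randomly generated, standing in for the bmi, charges and smoker columns:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in data with the same column names as insurance.csv
rng = np.random.default_rng(7)
n = 200
smoker = rng.choice(['yes', 'no'], size=n, p=[0.2, 0.8])
bmi = rng.normal(30, 6, size=n)
charges = 250 * bmi + np.where(smoker == 'yes', 20000, 0) + rng.normal(0, 3000, size=n)
df = pd.DataFrame({'bmi': bmi, 'charges': charges, 'smoker': smoker})

colors = {'yes': 'red', 'no': 'green'}

fig, ax = plt.subplots()
for category, color in colors.items():
    subset = df[df['smoker'] == category]  # boolean mask selects one group
    ax.scatter(subset['bmi'], subset['charges'], c=color, label=category)

ax.set_xlabel('BMI')
ax.set_ylabel('Charges')
ax.set_title('BMI vs Insurance Charges for Smokers/Non-smokers')
ax.legend(title='smoker')  # one entry per labeled scatter call
ax.grid(True)
plt.show()
```

Because each category is drawn by its own scatter() call, the legend entries follow the order of the colors dictionary.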
Saving your Matplotlib Scatter Plot
You can save your plot as an image using the savefig() function. Plots can be saved as PNG, JPEG, PDF, and several other supported formats.
Let's try saving the 'BMI vs Insurance Charges' plot we created above. Call plt.savefig() before plt.show(), which may close the figure, so that the plot currently being drawn is the one that gets saved:
colors = {'yes':'red', 'no':'green'}
plt.scatter(x, y, c=df['smoker'].map(colors))
plt.ylabel('Charges')
plt.xlabel('BMI')
plt.title('BMI vs Insurance Charges')
plt.savefig('scatterplot.png')
plt.show()
The image is saved as 'scatterplot.png' in the current working directory.
To view the saved image, we’ll use the matplotlib.image module, as shown below:
# Display the saved image
import matplotlib.image as mpimg

image = mpimg.imread("scatterplot.png")
plt.imshow(image)
plt.show()
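savefig() also accepts options that are often useful when exporting: dpi sets the resolution of raster formats, and bbox_inches='tight' trims surplus whitespace around the axes. The output format is inferred from the file extension. A minimal sketch; the file names and the randomly generated data are placeholders:

```python
import os

import matplotlib.pyplot as plt
import numpy as np

# Synthetic data (randomly generated) standing in for the tutorial's columns
rng = np.random.default_rng(3)
x = rng.normal(30, 6, size=200)
y = 250 * x + rng.normal(0, 3000, size=200)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_xlabel('BMI')
ax.set_ylabel('Charges')
ax.set_title('BMI vs Insurance Charges')

# Save through the figure object before plt.show(), so this figure is written
fig.savefig('scatterplot_hires.png', dpi=200, bbox_inches='tight')  # raster at 200 dpi
fig.savefig('scatterplot.pdf')  # vector output, format taken from the extension

for path in ('scatterplot_hires.png', 'scatterplot.pdf'):
    print(path, os.path.getsize(path), 'bytes')
plt.show()
```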
Limitations of using Matplotlib Scatter Plot
- Overplotting
Overplotting occurs when there are so many data points that the markers overlap to the extent that it is difficult to see relationships between the variables.
One way to alleviate this is to plot a random subset of the data points, which can still give a general idea of the patterns in the entire dataset. Other ways are:
- Adding transparency to markers to allow for overlaps to be visible
- Reducing marker size so that fewer overlaps occur
In severe cases, though, a different chart type, such as a heat map of point density, is often the better choice.
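The remedies above can be compared side by side. A sketch on 50,000 synthetic points (randomly generated, chosen only to force heavy overlap) shows a random sample, small semi-transparent markers, and plt.hexbin(), a heat-map-style alternative that counts the points falling in each hexagonal cell. The sample fraction, alpha, s and gridsize values are arbitrary assumptions to tune for your own data:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic, deliberately dense data (randomly generated)
rng = np.random.default_rng(5)
n = 50_000
bmi = rng.normal(30, 6, size=n)
df = pd.DataFrame({'bmi': bmi, 'charges': 250 * bmi + rng.normal(0, 3000, size=n)})

fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharex=True, sharey=True)

# 1. Plot a random 5% sample of the rows
sample = df.sample(frac=0.05, random_state=0)
axes[0].scatter(sample['bmi'], sample['charges'])
axes[0].set_title('Random 5% sample')

# 2. All points, but small and mostly transparent
axes[1].scatter(df['bmi'], df['charges'], s=2, alpha=0.1)
axes[1].set_title('Small, transparent markers')

# 3. Hexagonal binning: color encodes how many points fall in each cell
hb = axes[2].hexbin(df['bmi'], df['charges'], gridsize=40, cmap='viridis')
fig.colorbar(hb, ax=axes[2], label='count')
axes[2].set_title('Hexbin density')

for ax in axes:
    ax.set_xlabel('BMI')
axes[0].set_ylabel('Charges')
plt.show()
```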
- Wrongly interpreting correlation and causation
Causation means that changes in one variable cause changes in the other. For example, there is a causal link between excessive junk food consumption and obesity. A causal relationship like this will usually show up as a correlation.
However, there is also a common phrase in statistics: 'Correlation does not imply causation.' An observed relationship may be caused by some third variable that affects both plotted variables, even though the two do not directly affect each other in any way. So, if a direct causal link cannot be established, further analysis is required to account for the effects of other potential variables.
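A small simulation makes the "third variable" point concrete. Below, a hypothetical hidden variable z drives both x and y, while x and y have no direct effect on each other. They still come out correlated, and the correlation essentially disappears once z's contribution is removed. All data are randomly generated:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Hidden third variable (the confounder)
z = rng.normal(size=n)
# x and y each depend on z plus independent noise, never on each other
x = z + rng.normal(size=n)
y = z + rng.normal(size=n)

r_xy = np.corrcoef(x, y)[0, 1]

# Remove z's contribution. In this simulation the effect of z on each
# variable is known to be exactly 1; with real data it would have to be estimated.
r_resid = np.corrcoef(x - z, y - z)[0, 1]

print(f"corr(x, y)            = {r_xy:.2f}")     # around 0.5
print(f"corr after removing z = {r_resid:.2f}")  # around 0.0
```

A scatter plot of x against y would show a clear upward trend, even though changing x would have no effect on y.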
Endnotes
A scatter plot, also known as a scattergram, scatter graph, or scatter chart, is one of the most commonly used chart types for visualizing data. Matplotlib is one of the oldest Python visualization libraries and provides a wide variety of charts and plots for analysis, including the scatter plot.