Visualization in Python with matplotlib
Grouping Observations
Another common use of scatter plots is cluster analysis: checking whether groups of observations are related to each other and whether the observations can be divided into groups statistically.
As we mentioned earlier, the `c` parameter of the `.scatter()` function can be an array of colors with the same length as the data. This means each observation can get its own color (likewise, the `s` parameter can set a size for each point). In this chapter, we will color each country according to the life expectancy group it falls into.
Let's expand the previous example. Suppose we want to split the data into categories based on the values of the 'life exp' column (life expectancy). The minimum value in this column is a little over 50, and the maximum is a little under 85. Using the pandas `cut()` function, we can create a new column in the `data` DataFrame and then plot each point according to its group. To do that, we use a `for` loop to iterate over the groups. We also need to predefine a list of colors. At each step, we call the `.scatter()` function once for the current group.
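Before choosing bin edges, it can help to check the column's range and to generate the edges and labels instead of typing them by hand. A sketch, assuming a hypothetical `life_exp` Series whose values are made up and stand in for `data['life exp']`:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for data['life exp'] (values are made up)
life_exp = pd.Series([50.6, 61.2, 70.4, 84.7])

# Check the range before picking the bin edges
print(life_exp.min(), life_exp.max())

# Edges 50, 55, ..., 85 (np.arange excludes the stop value 90)
bins = np.arange(50, 90, 5)
# Matching 'low-high' labels built from consecutive edges
labels = [f'{lo}-{hi}' for lo, hi in zip(bins[:-1], bins[1:])]

groups = pd.cut(life_exp, bins, labels=labels)
print(list(groups))
```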
```python
# Import the libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Reading the data
data = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/ed80401e-2684-4bc4-a077-99d13a386ac7/gapminder2017.csv', index_col = 0)

# Create Figure and Axes objects
fig, ax = plt.subplots()

# Split data into categories
data['life exp group'] = pd.cut(data['life exp'], [50, 55, 60, 65, 70, 75, 80, 85],
                                labels = ['50-55', '55-60', '60-65', '65-70', '70-75', '75-80', '80-85'])
colors = ['red', 'green', 'blue', 'purple', 'yellow', 'pink', 'black']

# Iterate over the groups in bin order (not order of appearance in the data),
# so each group always gets the same color and the legend is sorted
for group, color in zip(data['life exp group'].cat.categories, colors):
    temp_data = data.loc[data['life exp group'] == group]
    ax.scatter(temp_data['gdp per capita'], temp_data['internet users'], c = color, label = group)

# Display the legend and plot
plt.legend()
plt.show()
```
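The same plot can also be built with `DataFrame.groupby()` instead of `zip()` plus boolean filtering; this is a swapped-in alternative, not the method used in the lesson. The small DataFrame below is hypothetical (made-up values) so the sketch runs without downloading the CSV:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data standing in for the gapminder CSV (values are made up)
data = pd.DataFrame({
    'gdp per capita': [1500, 3200, 8000, 15000, 42000, 55000],
    'internet users': [10, 25, 40, 60, 85, 90],
    'life exp': [52.0, 58.5, 66.0, 72.3, 79.1, 83.4],
})

bins = [50, 55, 60, 65, 70, 75, 80, 85]
labels = ['50-55', '55-60', '60-65', '65-70', '70-75', '75-80', '80-85']
colors = ['red', 'green', 'blue', 'purple', 'yellow', 'pink', 'black']
data['life exp group'] = pd.cut(data['life exp'], bins, labels=labels)

# Fix each label's color up front so the pairing never depends on row order
color_of = dict(zip(labels, colors))

fig, ax = plt.subplots()
# groupby yields (label, sub-DataFrame) pairs in category order;
# observed=True skips categories with no rows (here '60-65')
for group, temp_data in data.groupby('life exp group', observed=True):
    ax.scatter(temp_data['gdp per capita'], temp_data['internet users'],
               c=color_of[group], label=group)
ax.legend()
plt.show()
```

With `observed=True`, empty groups produce no scatter call and no empty legend entry.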
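Here we called the pandas `cut()` function with three arguments:
- the first is the data we want to split into groups;
- the second is the list of bin edges (by default each bin is closed on the right, so `(50, 55]` includes 55 but not 50, and values outside every bin become `NaN`);
- the third, `labels`, gives the names of the groups. If `labels` is omitted, the categories are the intervals themselves, such as `(50, 55]`; passing `labels=False` returns integer bin indices starting from 0 instead.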
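A small sketch comparing the three `labels` options on a hypothetical Series (the values are made up; 55.0 is included to show that the right edge belongs to its bin):

```python
import pandas as pd

# Hypothetical life expectancy values
life_exp = pd.Series([52.3, 55.0, 61.0, 74.9, 83.1])
bins = [50, 55, 60, 65, 70, 75, 80, 85]
labels = ['50-55', '55-60', '60-65', '65-70', '70-75', '75-80', '80-85']

named = pd.cut(life_exp, bins, labels=labels)   # custom group names
intervals = pd.cut(life_exp, bins)              # default: Interval categories
codes = pd.cut(life_exp, bins, labels=False)    # integer bin indices from 0

print(list(named))
print(intervals[0])
print(list(codes))
```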
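We also used the built-in `zip()` function, which pairs up items from several iterables. It is convenient when, for example, we want to iterate over two lists of the same length in parallel. Since we iterated over two lists at once, the loop uses two variables, `group` and `color`. Note that `zip()` stops at the shortest input, so any extra colors are simply ignored.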
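A quick illustration of how `zip()` pairs items and stops at the shortest input, using hypothetical lists:

```python
# Hypothetical lists of unequal length: one color too many
groups = ['50-55', '55-60', '60-65']
colors = ['red', 'green', 'blue', 'purple']

pairs = list(zip(groups, colors))
for group, color in pairs:
    print(group, color)
```

In Python 3.10 and later, `zip(groups, colors, strict=True)` raises a `ValueError` instead of silently dropping the extra item, which can catch a mismatched color list early.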
In each `.scatter()` call we set the `label` parameter so that the groups can be shown in a legend (don't forget to call `plt.legend()`).
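A minimal sketch of how `label` values end up in the legend, with made-up points and group names; the `title` argument is optional:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Each labeled scatter call becomes one legend entry
ax.scatter([1, 2], [3, 4], c='red', label='group A')
ax.scatter([3, 4], [1, 2], c='blue', label='group B')

legend = ax.legend(title='Group')
texts = [t.get_text() for t in legend.get_texts()]
print(texts)
plt.show()
```

Calls without a `label` (or with a label starting with an underscore) are left out of the legend.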