Grouping Observations | Scatter Plots
Visualization in Python with matplotlib

Grouping Observations

Another common use of scatter plots is cluster analysis: looking for relationships between groups of observations and checking whether the observations can be statistically divided into groups.

As we mentioned earlier, the c parameter of the .scatter() function can be an array of colors with one entry per point. This means we can give each observation its own color (or, through the s parameter, its own size). In this chapter, we will color each country according to its life expectancy group.
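Because c accepts one color per point, a grouping can be drawn with a single .scatter() call. The sketch below is a minimal illustration that assumes a small hypothetical DataFrame and a hand-picked group_colors mapping. Neither comes from the gapminder data.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data standing in for the gapminder columns
df = pd.DataFrame({
    "gdp per capita": [1200, 5400, 15000, 42000],
    "internet users": [5.0, 30.0, 60.0, 90.0],
    "life exp group": ["50-55", "70-75", "70-75", "80-85"],
})

# Hypothetical mapping from group name to color
group_colors = {"50-55": "red", "70-75": "blue", "80-85": "black"}

# One color per row, in the same order as the points
point_colors = df["life exp group"].map(group_colors).tolist()

fig, ax = plt.subplots()
# A single call: c receives a list with one color per point
points = ax.scatter(df["gdp per capita"], df["internet users"], c=point_colors)
```

A single call like this cannot attach a separate legend label to each group, which is one reason the chapter's example loops over the groups instead.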

Let's expand the previous example. Suppose we want to split the data into categories based on the values of the 'life exp' column (life expectancy). The minimum value in this column is a little over 50, and the maximum is a little under 85. Using the pandas .cut() method, we can create a new column in the data DataFrame and then plot each point according to its group. To do that, we iterate over the groups with a for loop and call the .scatter() function once per group, so we also need a predefined list of colors.
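The bin edges and labels used in this chapter's code can also be generated from the range just described instead of being typed by hand. A minimal sketch, assuming a hypothetical life_exp Series in place of the real data['life exp'] column:

```python
import pandas as pd

# Hypothetical stand-in for data['life exp']; the real column comes from the CSV
life_exp = pd.Series([50.9, 63.2, 71.5, 84.6])
low, high = life_exp.min(), life_exp.max()

# Edges every 5 years from 50 up to and including 85
edges = list(range(50, 90, 5))

# Name each bin after its two edges, e.g. '50-55'
labels = [f"{a}-{b}" for a, b in zip(edges[:-1], edges[1:])]

# pd.cut uses intervals closed on the right by default, so every value must
# satisfy edges[0] < value <= edges[-1]; otherwise it would be assigned NaN
covered = edges[0] < low and high <= edges[-1]
```

Note that there is always one label fewer than there are edges, since seven bins are bounded by eight edges.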

# Import the libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Reading the data
data = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/ed80401e-2684-4bc4-a077-99d13a386ac7/gapminder2017.csv', index_col = 0)

# Create Figure and Axes objects
fig, ax = plt.subplots()

# Split data into categories
data['life exp group'] = pd.cut(data['life exp'], [50, 55, 60, 65, 70, 75, 80, 85],
                                labels = ['50-55', '55-60', '60-65', '65-70', '70-75', '75-80', '80-85'])
colors = ['red', 'green', 'blue', 'purple', 'yellow', 'pink', 'black']

# Initialize for loop
for group, color in zip(data['life exp group'].unique(), colors):
    temp_data = data.loc[data['life exp group'] == group]
    ax.scatter(temp_data['gdp per capita'], temp_data['internet users'], c = color, label = group)

# Display the legend and plot
plt.legend()
plt.show()

Here we used the pandas .cut() method with three arguments:

  • the first is the data we want to split into groups;
  • the second is the list of bin edges;
  • labels gives the names of the groups (if it is not set, each group is named after its interval, such as (50, 55]).
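The following sketch shows how pd.cut assigns values, using a hypothetical Series instead of the CSV column. It also illustrates two details the list above leaves implicit: intervals are closed on the right by default, so 55.0 lands in the first group, and labels=False returns zero-based integer bin codes instead of names.

```python
import pandas as pd

# Hypothetical life-expectancy values, chosen to hit a bin edge (55.0)
values = pd.Series([50.5, 55.0, 55.1, 84.9])
bins = [50, 55, 60, 65, 70, 75, 80, 85]

# Without labels: each value is tagged with its interval, e.g. (50, 55]
as_intervals = pd.cut(values, bins)

# With labels: the same intervals get the custom names
labels = ['50-55', '55-60', '60-65', '65-70', '70-75', '75-80', '80-85']
as_labels = pd.cut(values, bins, labels=labels)

# labels=False: integer bin indices starting at 0
as_codes = pd.cut(values, bins, labels=False)
```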

We also used the zip() function, which is convenient when we want to iterate over two sequences of the same length in parallel (if the lengths differ, zip() stops at the end of the shorter one). Since each step yields one item from each sequence, the loop uses two variables, group and color. In each .scatter() call we set the label parameter so that the legend can be displayed (don't forget to call plt.legend()).
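The zip() pairing and the label parameter can be checked on toy data. This sketch also swaps .unique() for .cat.categories: .unique() returns the groups in order of first appearance, so with the chapter's code the legend order, and which color each group receives, depends on the row order of the CSV. Iterating over the categories keeps them in bin order instead. The DataFrame, bins, and colors here are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data; rows are deliberately not sorted by life expectancy
df = pd.DataFrame({
    "gdp per capita": [40000, 1500, 12000, 2500],
    "internet users": [88.0, 9.0, 55.0, 20.0],
    "life exp": [82.0, 53.0, 74.0, 61.0],
})
df["life exp group"] = pd.cut(df["life exp"], [50, 60, 70, 80, 90],
                              labels=["50-60", "60-70", "70-80", "80-90"])
colors = ["red", "green", "blue", "purple"]

# .unique() follows row order, not bin order
appearance_order = list(df["life exp group"].unique())

fig, ax = plt.subplots()
# .cat.categories follows bin order, so each bin always gets the same color
for group, color in zip(df["life exp group"].cat.categories, colors):
    subset = df.loc[df["life exp group"] == group]
    # label is the text the legend shows for this call
    ax.scatter(subset["gdp per capita"], subset["internet users"], c=color, label=group)

legend_texts = [t.get_text() for t in ax.legend().get_texts()]
```

If a category has no rows, this loop still adds an (empty) legend entry for it, whereas the .unique() version skips it.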

Section 3. Chapter 7