Adding a Legend
When multiple elements are present in a chart, it is often helpful to label them for clarity. The legend serves this purpose by providing a compact area that explains different components of the chart.
There are three common ways to create a legend in matplotlib.
First Option
Consider the following example to clarify the concept:
```python
import matplotlib.pyplot as plt
import numpy as np

# Define categories and data
questions = ['question_1', 'question_2', 'question_3']
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])

# Set positions and bar width
positions = np.arange(len(questions))
width = 0.3

# Create the grouped bar chart
for i in range(len(answers)):
    plt.bar(positions + width * i, answers[i], width)

# Adjust x-axis ticks to the center of groups
plt.xticks(positions + width * (len(answers) - 1) / 2, questions)

# Setting the labels for the legend explicitly
plt.legend(['positive answers', 'negative answers'])
plt.show()
```
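In the upper left corner, a legend explains the different bars on the chart. This legend is created using the `plt.legend()` function, with a list of strings passed as its first positional argument (referred to as `labels` in the documentation). The labels are assigned to the bar groups in the order in which they were plotted.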
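Because `plt.legend(labels)` pairs labels with plotted elements only by their order, the matplotlib documentation discourages this form: the pairing is easy to mix up. A minimal sketch of the explicit alternative, `plt.legend(handles, labels)`, reusing the data above. The names `yes_bars`, `no_bars`, and `legend_labels` are illustrative, and the legend order is deliberately reversed to show that each label follows the handle it is paired with rather than the plotting order:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

questions = ['question_1', 'question_2', 'question_3']
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
positions = np.arange(len(questions))
width = 0.3

plt.figure()
# Keep the BarContainer objects returned by bar() so they can be passed as handles
yes_bars = plt.bar(positions, yes_answers, width)
no_bars = plt.bar(positions + width, no_answers, width)
plt.xticks(positions + width / 2, questions)

# Each label is paired with the handle at the same index in the two lists,
# so the legend order no longer depends on the plotting order
legend = plt.legend([no_bars, yes_bars], ['negative answers', 'positive answers'])
legend_labels = [text.get_text() for text in legend.get_texts()]
print(legend_labels)  # ['negative answers', 'positive answers']
```

In an interactive session, call `plt.show()` at the end to display the chart.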
Second Option
Another option is to specify the `label` parameter in each call of the plotting function (`bar()` in our example):
```python
import matplotlib.pyplot as plt
import numpy as np

# Define x-axis categories and their positions
questions = ['question_1', 'question_2', 'question_3']
positions = np.arange(len(questions))

# Define answers for each category
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])
labels = ['positive answers', 'negative answers']

# Set the width for each bar
width = 0.3

# Plot each category with a label
for i in range(len(answers)):
    plt.bar(positions + width * i, answers[i], width, label=labels[i])

# Set x-axis ticks and labels at the center of each group
plt.xticks(positions + width * (len(answers) - 1) / 2, questions)

# Automatically create legend from label parameters
plt.legend()
plt.show()
```
Here, `plt.legend()` automatically determines which elements to add to the legend and which labels to use: every element whose `label` parameter is set is included, except those whose label starts with an underscore.
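To see which elements `plt.legend()` would pick up before drawing it, you can ask the current axes for its legend handles and labels. A minimal sketch; the dashed reference line at 400 is a hypothetical addition (not part of the original chart) whose underscore-prefixed label keeps it out of the legend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

positions = np.arange(3)
width = 0.3

plt.figure()
plt.bar(positions, [500, 240, 726], width, label='positive answers')
plt.bar(positions + width, [432, 618, 101], width, label='negative answers')
# Hypothetical reference line; the leading underscore excludes it from the legend
plt.axhline(400, color='gray', linestyle='--', label='_reference')

# get_legend_handles_labels() returns the handles and labels
# that legend() would collect automatically
handles, labels = plt.gca().get_legend_handles_labels()
print(labels)  # ['positive answers', 'negative answers']
plt.legend()
```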
Third Option
There is also a third option: calling the `set_label()` method on the object returned by the plotting function (the `BarContainer` returned by `bar()` in our example):
```python
import matplotlib.pyplot as plt
import numpy as np

questions = ['question_1', 'question_2', 'question_3']
positions = np.arange(len(questions))
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])
width = 0.3
labels = ['positive answers', 'negative answers']

# Plot bars for each category, then label the returned BarContainer
for i in range(len(answers)):
    bars = plt.bar(positions + width * i, answers[i], width)
    bars.set_label(labels[i])

# Set x-axis ticks and labels at the center of the grouped bars
center_positions = positions + width * (len(answers) - 1) / 2
plt.xticks(center_positions, questions)

# Display the legend at the top of the axes, centered horizontally
plt.legend(loc='upper center')
plt.show()
```
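The same method works on other artists. A minimal sketch, assuming a line plot of total answers per question (an illustrative series derived from the data above, not part of the original example): `plt.plot()` returns a list of `Line2D` objects, so the single line is unpacked before calling `set_label()`:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

positions = np.arange(3)
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
total_answers = yes_answers + no_answers  # illustrative derived series

plt.figure()
# plot() returns a list of Line2D objects; unpack the single line it creates
line, = plt.plot(positions, total_answers, marker='o')
line.set_label('total answers')

legend = plt.legend()
print(line.get_label())  # total answers
```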
Legend Location
Another important keyword argument of the `legend()` function is `loc`, which specifies the location of the legend. Its default value is `'best'`, which tells matplotlib to choose automatically the location with the least overlap with the plotted data.
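When no `loc` is passed, `legend()` reads its default from the `legend.loc` entry of `matplotlib.rcParams`, which is `'best'` with the stock configuration. A minimal sketch, assuming you want a different default for only part of a script: `plt.rc_context()` changes the setting temporarily, and the bounding-box check at the end is only an illustrative way to confirm that the legend landed in the lower-left quarter of the axes:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

# With the default configuration this prints: best
print(matplotlib.rcParams['legend.loc'])

# Inside the context, legend() without loc= uses 'lower left'
with plt.rc_context({'legend.loc': 'lower left'}):
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4], label='data')
    legend = ax.legend()

# Compare the legend's and the axes' bounding boxes in display coordinates
fig.canvas.draw()
legend_box = legend.get_window_extent()
axes_box = ax.get_window_extent()
in_lower_left = (legend_box.x1 < axes_box.x0 + axes_box.width / 2
                 and legend_box.y1 < axes_box.y0 + axes_box.height / 2)
print(in_lower_left)
```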
The following example passes `loc` explicitly:
```python
import matplotlib.pyplot as plt
import numpy as np

questions = ['question_1', 'question_2', 'question_3']
positions = np.arange(len(questions))
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])
width = 0.3
labels = ['positive answers', 'negative answers']

# Plot bars for each category, then label the returned BarContainer
for i, label in enumerate(labels):
    bars = plt.bar(positions + width * i, answers[i], width)
    bars.set_label(label)

# Set x-axis ticks and labels at the center of the grouped bars
center_positions = positions + width * (len(answers) - 1) / 2
plt.xticks(center_positions, questions)

# Display the legend at the top of the axes, centered horizontally
plt.legend(loc='upper center')
plt.show()
```
In this example, the legend is positioned in the upper center of the plot. Other valid values for the `loc` parameter include `'upper right'`, `'upper left'`, `'lower left'`, `'lower right'`, `'right'`, `'center left'`, `'center right'`, `'lower center'`, and `'center'`.
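To compare these positions side by side, here is a sketch that draws one small panel per `loc` string, including `'best'` and `'upper center'`. The simple line plot is placeholder data, not taken from the examples above:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

loc_values = ['best', 'upper right', 'upper left', 'lower left', 'lower right',
              'right', 'center left', 'center right', 'lower center',
              'upper center', 'center']

# One panel per location string: 11 values in a 3 x 4 grid of 12 panels
fig, axes = plt.subplots(3, 4, figsize=(12, 8))
for ax, loc in zip(axes.flat, loc_values):
    ax.plot([0, 1, 2], [0, 1, 4], label='data')  # placeholder data
    ax.legend(loc=loc)
    ax.set_title(loc)
axes[-1, -1].axis('off')  # hide the unused 12th panel
fig.tight_layout()
# In an interactive session, call plt.show() to view the grid
```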
You can explore more options in the `legend()` documentation.
Task
- Label the lowest bars as `'primary sector'` by specifying the appropriate keyword argument.
- Label the bars in the middle as `'secondary sector'` by specifying the appropriate keyword argument.
- Label the top bars as `'tertiary sector'` by specifying the appropriate keyword argument.
- Place the legend on the right side, centered vertically.