Subplots
The subplots() function from pyplot serves exactly this purpose: creating a figure that contains a grid of several plots. It was already used in the first section to create a canvas; here it is examined in more detail.
Rows and Columns
The most important arguments of this function are nrows and ncols, which set the number of rows and columns in the subplot grid. By default, both are 1, so you get a single plot (one Axes object).
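Both arguments can also be passed by name, which makes the call easier to read. A minimal sketch (the 2x3 grid size is an arbitrary choice for illustration): for a grid with several rows and columns, the returned Axes objects are arranged in an array whose shape is (nrows, ncols).

```python
import matplotlib.pyplot as plt

# Keyword form of plt.subplots(2, 3): 2 rows, 3 columns
fig, axs = plt.subplots(nrows=2, ncols=3)

# The Axes objects are laid out as a (rows, columns) array
print(axs.shape)  # (2, 3)
```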
The subplots() function returns a Figure object and either a single Axes object or a NumPy array of Axes objects, depending on the grid size.
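To see how the return type depends on the grid size, the sketch below creates a few figures and inspects what comes back. It also uses the squeeze parameter, which this section does not otherwise cover: with squeeze=False, subplots() returns a 2D array even for a single subplot.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# 1x1 grid (the default): a single Axes object, not an array
fig1, ax = plt.subplots()
print(isinstance(ax, Axes))  # True

# A single row (or column): a 1D NumPy array of Axes
fig2, axs_row = plt.subplots(1, 3)
print(isinstance(axs_row, np.ndarray), axs_row.shape)  # True (3,)

# Several rows and columns: a 2D NumPy array of Axes
fig3, axs_grid = plt.subplots(2, 2)
print(axs_grid.shape)  # (2, 2)

# squeeze=False always gives a 2D array, even for one subplot
fig4, axs_2d = plt.subplots(squeeze=False)
print(axs_2d.shape)  # (1, 1)
```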
```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)
plt.show()
```
This creates a 2x2 grid of subplots.
Since there are multiple subplots, subplots() returns an array of Axes objects, usually stored in a variable called axs (the singular ax is used when there is only one plot).
In this case, axs is a two-dimensional array, so you need both a row and a column index to access a specific subplot.
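A short sketch of indexing the 2D array. Besides [row, column] indexing, a single index selects a whole row, and the array can also be unpacked directly into named variables (ax1 to ax4 are just illustrative names):

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)

# Row index first, then column index
top_right = axs[0, 1]
bottom_left = axs[1, 0]

# A single index selects a whole row: a 1D array of two Axes
print(axs[1].shape)  # (2,)

# Alternatively, unpack the 2D array straight into named variables
fig2, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
ax2.set_title("row 0, column 1")
```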
```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Creating a 2x2 subplot grid
fig, axs = plt.subplots(2, 2)

# Creating a different plot for each Axes object
axs[0, 0].plot(data_linear)
axs[0, 1].plot(data_squared)
axs[1, 0].scatter(data_linear, data_linear)
axs[1, 1].scatter(data_linear, data_squared)

plt.show()
```
The first row (row 0) has two line plots, and the second row (row 1) has two scatter plots.
Since each plot is placed on a separate subplot, plt.plot() or plt.scatter() cannot be used to target a specific subplot, because they always draw on the current Axes. Instead, call the corresponding method (such as .plot() or .scatter()) on each individual Axes object.
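The sketch below shows what happens if plt.plot() is used anyway right after plt.subplots(). pyplot functions draw on the current Axes, which, in the matplotlib versions we are aware of, is the last subplot created (the bottom-right one in a 2x2 grid):

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)

# The current Axes is the last subplot created
print(plt.gca() is axs[1, 1])  # True

# pyplot draws on the current Axes, not on the top-left subplot
plt.plot([1, 2, 3])
print(len(axs[1, 1].lines), len(axs[0, 0].lines))  # 1 0

# Calling the method on a specific Axes targets that subplot directly
axs[0, 0].plot([3, 2, 1])
print(len(axs[0, 0].lines))  # 1
```

If you really want to use pyplot functions, plt.sca(ax) makes ax the current Axes first, but calling methods on the Axes objects is usually clearer.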
Converting to a 1D Array
It is also possible to use the .ravel() method to convert the 2D array of Axes objects into a contiguous, flattened 1D array:
```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Creating a 2x2 subplot grid
fig, axs = plt.subplots(2, 2)

# Flattening axs to a 1D array for easier indexing
axs = axs.ravel()

# Creating a different plot for each Axes object
axs[0].plot(data_linear)
axs[1].plot(data_squared)
axs[2].scatter(data_linear, data_linear)
axs[3].scatter(data_linear, data_squared)

plt.show()
```
With a 2x2 grid, axs.ravel() returns a 1D array of four elements, ordered row by row, so each subplot can be accessed with a single index.
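Because ravel() flattens the array in row-major order, index 1 is the top-right subplot and index 2 is the bottom-left one. A flat array also makes looping over all subplots straightforward; the pairing of datasets to subplots below is just for illustration:

```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

fig, axs = plt.subplots(2, 2)
flat = axs.ravel()

# Row-major order: row 0 left to right, then row 1
print(flat[1] is axs[0, 1], flat[2] is axs[1, 0])  # True True

# One loop plots a dataset on every subplot
datasets = [data_linear, data_squared, data_linear, data_squared]
for ax, data in zip(flat, datasets):
    ax.plot(data)
```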
Sharing an Axis
The subplots() function also has sharex and sharey parameters, which control whether the x-axis or the y-axis is shared across subplots. Both are set to False by default.
```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Create a 2x2 subplot grid with shared x-axis across all subplots
fig, axs = plt.subplots(2, 2, sharex=True)

# Flatten axs array for easier indexing
axs = axs.ravel()

# Plotting different data on each subplot
axs[0].plot(data_linear)
axs[1].plot(data_squared)
axs[2].scatter(data_linear, data_linear)
axs[3].scatter(data_linear, data_squared)

plt.show()
```
Setting sharex=True shares the x-axis across all subplots. This is useful here because all four subplots cover nearly the same x range: the line plots use the indices 0 to 9, and the scatter plots use the values 1 to 10.
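To see what sharing actually changes, the sketch compares the x-axis limits of two stacked subplots with and without sharex=True. The arrays short_data and long_data are illustrative; they differ only in length, so their line plots cover different x ranges:

```python
import matplotlib.pyplot as plt
import numpy as np

short_data = np.arange(1, 4)   # plotted at x = 0..2
long_data = np.arange(1, 21)   # plotted at x = 0..19

# Without sharing, each subplot autoscales its own x-axis
fig1, axs1 = plt.subplots(2, 1)
axs1[0].plot(short_data)
axs1[1].plot(long_data)
print(axs1[0].get_xlim() == axs1[1].get_xlim())  # False

# With sharex=True, both subplots use the same x-axis limits
fig2, axs2 = plt.subplots(2, 1, sharex=True)
axs2[0].plot(short_data)
axs2[1].plot(long_data)
print(axs2[0].get_xlim() == axs2[1].get_xlim())  # True
```

When the x-axis is shared, subplots() also hides the x tick labels of all but the bottom subplot by default, since they would only repeat the same values.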
You can also set sharex or sharey to 'row' to share the axis among the subplots in each row, or to 'col' to share it among the subplots in each column.
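A minimal sketch of sharey='row', using made-up data whose scale grows from row to row: subplots in the same row end up with identical y-limits, while each row is still scaled independently.

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2, sharey="row")

# Row 0: small values; row 1: much larger values
axs[0, 0].plot([1, 2, 3])
axs[0, 1].plot([10, 20, 30])
axs[1, 0].plot([100, 200, 300])
axs[1, 1].plot([1000, 2000, 3000])

# Subplots in the same row share their y-limits...
print(axs[0, 0].get_ylim() == axs[0, 1].get_ylim())  # True
# ...but the two rows are scaled independently
print(axs[0, 0].get_ylim() == axs[1, 0].get_ylim())  # False
```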
As usual, feel free to explore the subplots() documentation for more details.
Task
- Use the correct function to create a subplot grid.
- The grid should have 3 rows and 1 column (specify the first two parameters).
- Specify the rightmost keyword argument so that the x-axis is shared among all the subplots.
- Store the result of the function for creating subplots in the fig and axs variables (from left to right).
- Place the first line plot, for data_linear, on the first row (row 0) of the subplot grid.
- Place the second line plot, for data_squared, on the second row (row 1) of the subplot grid.
- Place the third line plot, for data_exp, on the third row (row 2) of the subplot grid.