26

I'm trying to create a plotting function that takes as input the number of required plots and plots them using pylab.subplots and the sharex=True option. If the number of required plots is odd, then I would like to remove the last panel and force the tick labels on the panel right above it. I can't find a way of doing that and using the sharex=True option at the same time. The number of subplots can be quite large (>20).

Here's sample code. In this example I want to force xtick labels when i=3.

import numpy as np
import matplotlib.pylab as plt

def main():
    n = 5
    nx = 100
    x = np.arange(nx)
    # Integer division (//) keeps the row count an int under Python 3 as well
    if n % 2 == 0:
        f, axs = plt.subplots(n // 2, 2, sharex=True)
    else:
        f, axs = plt.subplots(n // 2 + 1, 2, sharex=True)
    for i in range(n):
        y = np.random.rand(nx)
        if i % 2 == 0:
            axs[i // 2, 0].plot(x, y, '-', label='plot '+str(i+1))
            axs[i // 2, 0].legend()
        else:
            axs[i // 2, 1].plot(x, y, '-', label='plot '+str(i+1))
            axs[i // 2, 1].legend()
    if n % 2 != 0:
        # Remove the unused last panel
        f.delaxes(axs[i // 2, 1])
    f.show()


if __name__ == "__main__":
    main()
Trenton McKinney
LeoC

5 Answers

48

To put it simply, make your subplots call for an even number of panels (in this case 6 plots):

f, ax = plt.subplots(3, 2, figsize=(12, 15))

Then you delete the one you don't need:

f.delaxes(ax[2,1]) # The indexing is zero-based here

This question and response look at this in an automated fashion, but I thought it worthwhile to post the basic use case here.
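As a minimal sketch of the same idea for an arbitrary number of plots: round the grid up to a full rectangle, then delete whatever is left over. The helper name `make_grid` and its `ncols` parameter are hypothetical, chosen only for illustration.

```python
import matplotlib.pyplot as plt


def make_grid(n_plots, ncols=2):
    """Hypothetical helper: build a full grid, then drop the unused axes."""
    nrows = -(-n_plots // ncols)      # ceiling division
    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    flat = axs.ravel()                # row-major order
    for ax in flat[n_plots:]:
        fig.delaxes(ax)               # remove the empty trailing panel(s)
    return fig, flat[:n_plots]


fig, used = make_grid(5)              # 3 x 2 grid with the last panel removed
```

Call `plt.show()` or `fig.savefig(...)` afterwards to view the result.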

DannyMoshe
18

If you replace the last if in your main function with this:

if n % 2 != 0:
    # Make the x tick labels visible on the panel directly above the deleted one
    for l in axs[i//2-1, 1].get_xaxis().get_majorticklabels():
        l.set_visible(True)
    f.delaxes(axs[i//2, 1])

f.show()

It should do the trick:

[Image: resulting plot]
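On more recent Matplotlib releases, shared-axis tick labels appear to be hidden through tick parameters rather than per-label visibility, so the `set_visible(True)` loop may have no visible effect there. A hedged alternative, assuming that behaviour, is to re-enable the labels with `tick_params`:

```python
import numpy as np
import matplotlib.pyplot as plt

n = 5
x = np.arange(100)
nrows = (n + 1) // 2                      # 3 rows for 5 plots
fig, axs = plt.subplots(nrows, 2, sharex=True)

for i in range(n):
    ax = axs[i // 2, i % 2]
    ax.plot(x, np.random.rand(x.size), '-', label='plot ' + str(i + 1))
    ax.legend()

if n % 2 != 0:
    fig.delaxes(axs[-1, 1])               # drop the unused last panel
    # Re-enable x tick labels on the panel directly above the deleted one
    axs[-2, 1].tick_params(axis='x', labelbottom=True)
```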

Primer
  • Using `delaxes` seems inefficient when dealing with a large number of axes objects (subplots). I ended up doing this with `add_subplot` instead. But I'm not sure how to acquire the `sharex` or `sharey` that's available with `subplots`. – CMCDragonkai Apr 17 '18 at 04:31
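Regarding the comment above: `add_subplot` accepts `sharex` and `sharey` keyword arguments that take an existing Axes, so sharing can be set up without `subplots`. A minimal sketch follows; the helper name `add_shared_subplots` is hypothetical.

```python
import matplotlib.pyplot as plt


def add_shared_subplots(fig, n_plots, ncols=2):
    """Hypothetical helper: add exactly n_plots axes that share one x-axis."""
    nrows = -(-n_plots // ncols)          # ceiling division
    axes = []
    for i in range(n_plots):
        share = axes[0] if axes else None  # the first axes has nothing to share with
        axes.append(fig.add_subplot(nrows, ncols, i + 1, sharex=share))
    return axes


fig = plt.figure()
axes = add_shared_subplots(fig, 5)
axes[0].set_xlim(0, 10)                   # propagates to every sharing axes
```

Unlike `plt.subplots(sharex=True)`, this does not hide the inner tick labels, so every panel keeps its own labels unless you turn them off yourself.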
5

I generate an arbitrary number of subplots all the time (sometimes the data leads to 3 subplots, sometimes 13, etc). I wrote a little utility function to stop having to think about it.

The two functions I define are as follows. You can change the stylistic choices to match your preferences.

import math
import numpy as np
from matplotlib import pyplot as plt


def choose_subplot_dimensions(k):
    if k < 4:
        return k, 1
    elif k < 11:
        return math.ceil(k/2), 2
    else:
        # I've chosen to have a maximum of 3 columns
        return math.ceil(k/3), 3


def generate_subplots(k, row_wise=False):
    nrow, ncol = choose_subplot_dimensions(k)
    # Choose your share X and share Y parameters as you wish:
    figure, axes = plt.subplots(nrow, ncol,
                                sharex=True,
                                sharey=False)

    # Check if it's an array. If there's only one plot, it's just an Axes obj
    if not isinstance(axes, np.ndarray):
        return figure, [axes]
    else:
        # Choose the traversal you'd like: 'F' is col-wise, 'C' is row-wise
        axes = axes.flatten(order=('C' if row_wise else 'F'))

        # Delete any unused axes from the figure, so that they don't show
        # blank x- and y-axis lines
        for idx, ax in enumerate(axes[k:]):
            figure.delaxes(ax)

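            # Turn tick labels on for the last ax in each column, wherever it lands.
            # tick_params is used here because newer Matplotlib versions hide
            # shared tick labels through tick parameters.
            idx_to_turn_on_ticks = idx + k - ncol if row_wise else idx + k - 1
            axes[idx_to_turn_on_ticks].tick_params(axis='x', labelbottom=True)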

        axes = axes[:k]
        return figure, axes

And here's example usage with 13 subplots:

x_variable = list(range(-5, 6))
parameters = list(range(0, 13))

figure, axes = generate_subplots(len(parameters), row_wise=True)
for parameter, ax in zip(parameters, axes):
    ax.plot(x_variable, [x**parameter for x in x_variable])
    ax.set_title(label="y=x^{}".format(parameter))

plt.tight_layout()
plt.show()

Which produces the following:

[Image: the 13 subplots, filled row-wise]

Or, switching to column-wise traversal order (generate_subplots(..., row_wise=False)) generates:

[Image: the 13 subplots, filled column-wise]
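To see how the layout helper behaves, here is a quick check of the grid it picks for a few values of k, and how many axes end up unused (and therefore deleted). The function is copied from above so the snippet runs on its own; Python 3 is assumed.

```python
import math


def choose_subplot_dimensions(k):
    # Same rules as above: 1 column below 4 plots, 2 below 11, else 3
    if k < 4:
        return k, 1
    elif k < 11:
        return math.ceil(k / 2), 2
    else:
        return math.ceil(k / 3), 3


for k in (3, 5, 10, 13):
    nrow, ncol = choose_subplot_dimensions(k)
    unused = nrow * ncol - k          # axes that generate_subplots would delete
    print(k, (nrow, ncol), unused)
# 3 (3, 1) 0
# 5 (3, 2) 1
# 10 (5, 2) 0
# 13 (5, 3) 2
```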

canary_in_the_data_mine
3

Instead of calculating which subplot needs to be deleted, you can check which subplots have nothing plotted in them. You can look at this answer for various methods to check whether something is plotted on an axis. Using the method ax.has_data(), you can simplify your function like this:

import numpy as np
import matplotlib.pyplot as plt


def main():
    n = 5
    max_width = 2  # images per row
    # Ceiling division: no extra empty row when n is a multiple of max_width
    height, width = -(-n // max_width), max_width
    # squeeze=False keeps axs a 2D array, so axs.flat always works
    fig, axs = plt.subplots(height, width, sharex=True, squeeze=False)

    for i in range(n):
        nx = 100
        x = np.arange(nx)
        y = np.random.rand(nx)
        ax = axs.flat[i]
        ax.plot(x, y, '-', label='plot '+str(i+1))
        ax.legend(loc="upper right")

    # access each axes object via axs.flat
    for ax in axs.flat:
        # check if something was plotted
        if not ax.has_data():
            fig.delaxes(ax)  # delete if nothing is plotted in the axes obj

    fig.show()

[Image: resulting plot]

You can set how many plots you want with the n variable and how many plots per row with the max_width variable.
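A small self-contained illustration of the `has_data()` check on a 2 x 2 grid where only three panels are used. As far as I know, `has_data()` only looks for plotted artists such as lines, patches, collections and images, so a panel holding nothing but text would still count as empty; treat that as an assumption to verify for your use case.

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)

# Plot on the first three panels only
for ax in axs.flat[:3]:
    ax.plot([0, 1], [0, 1])

# Collect the panels that never received any data, then delete them
empty = [ax for ax in axs.flat if not ax.has_data()]
for ax in empty:
    fig.delaxes(ax)
```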

Aelius
  • Excellent answer. Worked fine, and suitable to a dynamic environment when the number of subplots is not known. Thanks. – Wrichik Basu Oct 14 '22 at 14:16
-4

In Python 3, you can delete the unused axes as below:

import matplotlib.pyplot as plt
import seaborn as sns

# sales_data is assumed to be a pandas DataFrame that is already loaded.
# I have 5 plots that I want to show in 2 rows, so I use 3 columns. That way I have 6 plots.
f, axes = plt.subplots(2, 3, figsize=(20, 10))

sns.countplot(x=sales_data['Gender'], order=sales_data['Gender'].value_counts().index, palette="Set1", ax=axes[0, 0])
sns.countplot(x=sales_data['Age'], order=sales_data['Age'].value_counts().index, palette="Set1", ax=axes[0, 1])
sns.countplot(x=sales_data['Occupation'], order=sales_data['Occupation'].value_counts().index, palette="Set1", ax=axes[0, 2])
sns.countplot(x=sales_data['City_Category'], order=sales_data['City_Category'].value_counts().index, palette="Set1", ax=axes[1, 0])
sns.countplot(x=sales_data['Marital_Status'], order=sales_data['Marital_Status'].value_counts().index, palette="Set1", ax=axes[1, 1])

# This line will delete the last empty plot
f.delaxes(ax=axes[1, 2])
Jagannath Banerjee
  • This answer is not useful at all: the code refers to an external module not requested (seaborn) and includes some external data (Gender, Age, etc.) to which the reader has no access. Then this is substantially a repost of the solution proposed by others without any improvement or other original work. – Aelius Nov 08 '21 at 16:33