
I have a figure consisting of 3 subplots. I would like to place the last subplot in the middle of the second row; currently it sits in the bottom left of the figure. How do I do this? I could not find an answer on Stack Overflow.

    import numpy as np
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(13,10))
    ax = axes.flatten()
    ax[0].plot(vDT, np.cumsum(mWopt0[asset0,:])*percentage/iTT, label='COAL, c = 0')
    ax[0].legend()
    ax[0].set_title('Proportion in most invested stock')
    ax[1].plot(vDT, np.cumsum(mWopt01[asset01,:])*percentage/iTT, label='OINL, c = 0.1')
    ax[1].plot(vDT, np.cumsum(mWopt03[asset03,:])*percentage/iTT, label='OINL, c = 0.3')
    ax[1].plot(vDT, np.cumsum(mWopt05[asset05,:])*percentage/iTT, label='OINL, c = 0.5')
    ax[1].plot(vDT, np.cumsum(mWopt2[asset2,:])*percentage/iTT, label='OINL, c = 2')
    ax[1].plot(vDT, np.cumsum(mWopt5[asset5,:])*percentage/iTT, label='OINL, c = 5')
    ax[1].plot(vDT, np.cumsum(mWopt10[asset10,:])*percentage/iTT, label='OINL, c = 10')
    ax[1].legend()
    ax[1].set_title('Proportion in most invested stock')
    ax[2].plot(vDT, np.cumsum(mWopt0[index,:])*percentage/iTT, label='c = 0')
    ax[2].plot(vDT, np.cumsum(mWopt01[index,:])*percentage/iTT, label='c = 0.1')
    ax[2].plot(vDT, np.cumsum(mWopt03[index,:])*percentage/iTT, label='c = 0.3')
    ax[2].plot(vDT, np.cumsum(mWopt05[index,:])*percentage/iTT, label='c = 0.5')
    ax[2].plot(vDT, np.cumsum(mWopt2[index,:])*percentage/iTT, label='c = 2')
    ax[2].plot(vDT, np.cumsum(mWopt5[index,:])*percentage/iTT, label='c = 5')
    ax[2].plot(vDT, np.cumsum(mWopt10[index,:])*percentage/iTT, label='c = 10')
    ax[2].legend()
    ax[2].set_title('Proportion invested in index')
    ax[0].set_ylabel('Expanding window weight')
    ax[1].set_ylabel('Expanding window weight')
    ax[2].set_ylabel('Expanding window weight')
    ax[3].remove()  # drop the unused fourth axes; ax[2] stays in the bottom-left cell
    fig.autofmt_xdate(bottom=0.2, rotation=75, ha='right')
    plt.savefig('NSE_por_unrestricted_mostweightSI.jpg', bbox_inches='tight')
    plt.show()
user9891079

2 Answers


matplotlib.gridspec.GridSpec solves your problem: slices of a GridSpec can be passed to plt.subplot to place each axis. With a 4x4 grid, every plot can span two rows and two columns, so the third plot can be placed in the middle of the bottom half:

    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt

    gs = gridspec.GridSpec(4, 4)

    # top left: rows 0-1, columns 0-1
    ax1 = plt.subplot(gs[:2, :2])
    ax1.plot(range(0,10), range(0,10))

    # top right: rows 0-1, columns 2-3
    ax2 = plt.subplot(gs[:2, 2:])
    ax2.plot(range(0,10), range(0,10))

    # bottom: rows 2-3, columns 1-2, i.e. horizontally centered
    ax3 = plt.subplot(gs[2:4, 1:3])
    ax3.plot(range(0,10), range(0,10))

    plt.show()

You can check out the demos for gridspec here: https://matplotlib.org/tutorials/intermediate/gridspec.html#sphx-glr-tutorials-intermediate-gridspec-py

The only catch is that your code uses the fig, axes = plt.subplots(...) pattern, which does not combine directly with GridSpec. You would need to refactor a bit: create the figure with plt.figure() and add each axis with fig.add_subplot(gs[...]).
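As a sketch of that refactor (not the original poster's code), the question's layout can be rebuilt with fig.add_gridspec(2, 4): each plot spans two of the four columns, so the third one can take the middle two columns of the bottom row. The x values and random curves are hypothetical placeholders for vDT and the mWopt* arrays, and the Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical placeholder data standing in for vDT and the cumulative weights.
x = np.linspace(0, 1, 50)
rng = np.random.default_rng(0)

fig = plt.figure(figsize=(13, 10))
gs = fig.add_gridspec(2, 4)  # 2 rows x 4 columns; every plot spans 2 columns
ax = [
    fig.add_subplot(gs[0, 0:2]),  # top left
    fig.add_subplot(gs[0, 2:4]),  # top right
    fig.add_subplot(gs[1, 1:3]),  # bottom row, middle two columns, so it is centered
]

titles = ["Proportion in most invested stock",
          "Proportion in most invested stock",
          "Proportion invested in index"]
for a, title in zip(ax, titles):
    a.plot(x, np.cumsum(rng.random(x.size)), label="placeholder")
    a.legend()
    a.set_title(title)
    a.set_ylabel("Expanding window weight")

# Same call as in the question: rotates the bottom-row tick labels, hides the top-row ones.
fig.autofmt_xdate(bottom=0.2, rotation=75, ha="right")
```

Because the axes are still collected in a list named ax, the question's per-axis calls (ax[0].plot, ax[2].legend, and so on) can stay as they are; only ax[3].remove() goes away, since no fourth axes is created.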

Charles Landau

I created a snippet of code (gist) to center the subplots of the last row, for an arbitrary number of subplots:

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec


    def subplots_centered(nrows, ncols, figsize, nfigs):
        """
        Modification of matplotlib's plt.subplots(),
        useful when some subplots are empty.

        It returns a grid where the plots
        in the **last** row are centered.

        Inputs
        ------
            nrows, ncols, figsize: same as plt.subplots()
            nfigs: actual number of subplots to create
        """
        assert nfigs < nrows * ncols, "No empty subplots, use normal plt.subplots() instead"

        fig = plt.figure(figsize=figsize)
        axs = []

        m = nfigs % ncols
        m = range(1, ncols+1)[-m]  # subdivision of columns: each subplot spans m grid columns
        gs = gridspec.GridSpec(nrows, m*ncols)

        last_row = (nfigs - 1) // ncols  # last row that actually contains plots

        for i in range(0, nfigs):
            row = i // ncols
            col = i % ncols

            if row == last_row and nfigs % ncols:  # center only a partially filled last row
                off = int(m * (ncols - nfigs % ncols) / 2)
            else:
                off = 0

            ax = fig.add_subplot(gs[row, m*col + off : m*(col+1) + off])
            axs.append(ax)

        return fig, axs

    fig, axs = subplots_centered(nrows=4, ncols=3, figsize=(10,7), nfigs=11)
    plt.tight_layout()
    plt.show()
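To see why the bottom plots end up centered, the column arithmetic of subplots_centered can be pulled out into a small helper that needs no figure. centered_spans is a hypothetical name, not part of Matplotlib; it reuses the same m and offset formulas and, as an assumption, centers the last row that actually contains plots. It returns the width of the finer grid and a (row, start, stop) column span for each subplot.

```python
def centered_spans(nrows, ncols, nfigs):
    """Hypothetical helper: the column spans used by subplots_centered.

    Returns (total_columns, spans) where spans[i] = (row, start, stop)
    corresponds to the GridSpec slice gs[row, start:stop] for subplot i.
    """
    assert 0 < nfigs < nrows * ncols, "need at least one empty cell"
    r = nfigs % ncols                # plots in the partially filled last row
    m = ncols - r + 1 if r else 1    # equals range(1, ncols + 1)[-r]
    last_row = (nfigs - 1) // ncols  # last row that actually holds plots
    spans = []
    for i in range(nfigs):
        row, col = divmod(i, ncols)
        # Shift the partial last row right by half of its empty width.
        off = m * (ncols - r) // 2 if (row == last_row and r) else 0
        spans.append((row, m * col + off, m * (col + 1) + off))
    return m * ncols, spans


# Usage: the question's layout (2 x 2 grid, 3 plots).
total, spans = centered_spans(2, 2, 3)
print(total, spans)  # 4 [(0, 0, 2), (0, 2, 4), (1, 1, 3)]
```

For the question's case this gives a 4-column grid with the third plot in columns 1 to 2 (stop is exclusive), the same horizontal placement as gs[2:4, 1:3] in the 4x4 GridSpec answer. The choice m = ncols - r + 1 makes m * (ncols - r) a product of two consecutive integers, hence always even, so the offset is a whole number of grid columns.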
