
I'm plotting on two figures and each of these figures has multiple subplots. I need to do this inside a single loop. Here is what I do when I have only one figure:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=6,ncols=6,figsize=(20, 20))
fig.subplots_adjust(hspace=.5,wspace=0.4)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)

for x in range(1,32):
    plt.subplot(6,6,x)
    plt.title('day='+str(x))
    plt.scatter(x1,y1)
    plt.scatter(x2,y2)
    plt.colorbar().set_label('Distance from ocean',rotation=270)
plt.savefig('Plots/everyday_D color.png')    
plt.close()

Now I know when you have multiple figures you need to do something like this:

fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()

But I don't know how to plot in the loop so that each subplot is in its place (because you can't keep using plt.scatter if there are two figures). Please be specific about what I need to do (regarding whether it is fig1.scatter, ax1.scatter, fig.subplots_adjust, ... and how to save and close at the end).

Dr proctor

2 Answers


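Each pyplot function has a corresponding method in the object-oriented API (for example, plt.title becomes ax.set_title, plt.scatter becomes ax.scatter, and plt.subplots_adjust becomes fig.subplots_adjust). If you really want to loop over both figures' axes at the same time, it would look like this: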

import numpy as np
import matplotlib.pyplot as plt

x1 = x2 = np.arange(10)
y1 = y2 = c = np.random.rand(10,6)

fig1, axes1 = plt.subplots(nrows=2,ncols=3)
fig1.subplots_adjust(hspace=.5,wspace=0.4)

fig2, axes2 = plt.subplots(nrows=2,ncols=3)
fig2.subplots_adjust(hspace=.5,wspace=0.4)

for i, (ax1,ax2) in enumerate(zip(axes1.flatten(), axes2.flatten())):
    ax1.set_title('day='+str(i))
    ax2.set_title('day='+str(i))
    sc1 = ax1.scatter(x1,y1[:,i], c=c[:,i])
    sc2 = ax2.scatter(x2,y2[:,i], c=c[:,i])
    fig1.colorbar(sc1, ax=ax1)
    fig2.colorbar(sc2, ax=ax2)

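fig1.savefig("plot1.png")  # save each figure to its own file
fig2.savefig("plot2.png")
plt.show()
plt.close("all")  # plt.close() alone would close only the current figure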

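Here you loop over the two flattened axes arrays, so that ax1 and ax2 are the matplotlib axes (matplotlib.axes.Axes) to plot to. axes1 and axes2 are 2×3 NumPy arrays holding those axes, as returned by plt.subplots, and fig1 and fig2 are matplotlib figures (matplotlib.figure.Figure).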
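To check these types yourself, a quick sketch (it uses the non-interactive Agg backend so it runs without a display; the 2×3 grid mirrors the example above):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

fig1, axes1 = plt.subplots(nrows=2, ncols=3)

print(isinstance(fig1, Figure))   # fig1 is a Figure
print(type(axes1), axes1.shape)   # axes1 is a NumPy array of shape (2, 3)
ax1 = axes1.flatten()[0]          # one element of that array ...
print(isinstance(ax1, Axes))      # ... is an Axes you can plot on
ax1.scatter(np.arange(10), np.arange(10))
plt.close(fig1)
```

So indexing or flattening axes1 gives you Axes objects, and it is on those that you call scatter, set_title and so on.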

In order to obtain an index as well, enumerate is used. So the line

for i, (ax1,ax2) in enumerate(zip(axes1.flatten(), axes2.flatten())):
    # loop code

is equivalent here to

for i in range(6):
    ax1 = axes1.flatten()[i]
    ax2 = axes2.flatten()[i]
    # loop code

or

i = 0
for ax1,ax2 in zip(axes1.flatten(), axes2.flatten()):
    # loop code
    i += 1

which are both longer to write.
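A small sketch that runs all three loop forms side by side and records which index and which Axes objects (via id()) each step receives, to confirm they visit the same pairs in the same order:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt

fig1, axes1 = plt.subplots(nrows=2, ncols=3)
fig2, axes2 = plt.subplots(nrows=2, ncols=3)

# 1) enumerate + zip
pairs_enumerate = []
for i, (ax1, ax2) in enumerate(zip(axes1.flatten(), axes2.flatten())):
    pairs_enumerate.append((i, id(ax1), id(ax2)))

# 2) explicit index with range
pairs_range = []
for i in range(6):
    ax1 = axes1.flatten()[i]
    ax2 = axes2.flatten()[i]
    pairs_range.append((i, id(ax1), id(ax2)))

# 3) zip with a manual counter
pairs_counter = []
i = 0
for ax1, ax2 in zip(axes1.flatten(), axes2.flatten()):
    pairs_counter.append((i, id(ax1), id(ax2)))
    i += 1

print(pairs_enumerate == pairs_range == pairs_counter)  # True
plt.close("all")
```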

At this point you may be interested to know that, although the above solution using the object-oriented API is surely more versatile and preferable, a pure pyplot solution is still possible. This would look like

import numpy as np
import matplotlib.pyplot as plt

x1 = x2 = np.arange(10)
y1 = y2 = c = np.random.rand(10,6)

plt.figure(1)
plt.subplots_adjust(hspace=.5,wspace=0.4)

plt.figure(2)
plt.subplots_adjust(hspace=.5,wspace=0.4)

for i in range(6):
    plt.figure(1)
    plt.subplot(2,3,i+1)
    sc1 = plt.scatter(x1,y1[:,i], c=c[:,i])
    plt.colorbar(sc1)

    plt.figure(2)
    plt.subplot(2,3,i+1)
    sc2 = plt.scatter(x2,y2[:,i], c=c[:,i])
    plt.colorbar(sc2)

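plt.figure(1)
plt.savefig("plot1.png")  # plt.savefig saves only the current figure
plt.figure(2)
plt.savefig("plot2.png")
plt.show()
plt.close("all")  # plt.close() alone would close only the current figure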
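In pyplot, plt.savefig and plt.close act only on the current figure. One way to save every open figure in a loop, sketched here with illustrative file names (plot1.png, plot2.png) and only one subplot per figure to keep it short, is to iterate over plt.get_fignums():

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")  # start with no open figures

x1 = x2 = np.arange(10)
y1 = y2 = np.random.rand(10, 6)

plt.figure(1)
plt.subplot(2, 3, 1)
plt.scatter(x1, y1[:, 0])

plt.figure(2)
plt.subplot(2, 3, 1)
plt.scatter(x2, y2[:, 0])

# plt.get_fignums() returns the numbers of all open figures
saved = []
for num in plt.get_fignums():
    plt.figure(num)              # make figure `num` the current one
    filename = f"plot{num}.png"
    plt.savefig(filename)        # saves the current figure only
    saved.append(filename)

plt.close("all")  # close every figure, not just the current one
print(saved)
```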
ImportanceOfBeingErnest
  • Thanks for your answer. I understand that enumerate basically creates indexing in the for loop. So are ax1 and ax2 just indexes? Can you please explain what type of object ax1, axes1, and fig1 in this example are? – Dr proctor Oct 05 '17 at 01:14
  • `ax1` and `ax2` are the matplotlib axes. `enumerate` is an easy way to get an index in a loop. Maybe [this](https://www.saltycrane.com/blog/2008/04/how-to-use-pythons-enumerate-and-zip-to/) is of use to understand enumerate better. I updated the answer as well. – ImportanceOfBeingErnest Oct 05 '17 at 08:20

Here's a version that shows how to run scatter plots on two different figures. Basically you reference the axes that are created with plt.subplots.

import matplotlib.pyplot as plt
import numpy as np

x1 = y1 = range(10)
x2 = y2 = range(5)

nRows = nCols = 6
fig1, axesArray1 = plt.subplots(nrows=nRows,ncols=nCols,figsize=(20, 20))
fig1.subplots_adjust(hspace=.5,wspace=0.4)
fig1.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)

fig2, axesArray2 = plt.subplots(nrows=nRows,ncols=nCols,figsize=(20, 20))
fig2.subplots_adjust(hspace=.5,wspace=0.4)
fig2.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)

days = range(1, 32)
dayRowCol = np.array([i + 1 for i in range(nRows * nCols)]).reshape(nRows, nCols)
for day in days:
    rowIdx, colIdx = np.argwhere(dayRowCol == day)[0]

    axis1 = axesArray1[rowIdx, colIdx]
    axis1.set_title('day=' + str(day))
    axis1.scatter(x1, y1)

    axis2 = axesArray2[rowIdx, colIdx]
    axis2.set_title('day=' + str(day))
    axis2.scatter(x2, y2)

    # This didn't run in the original script, so I left it as is
    # plt.colorbar().set_label('Distance from ocean',rotation=270)

fig1.savefig('plots/everyday_D1_color.png')
fig2.savefig('plots/everyday_D2_color.png')
plt.close('all')

When I ran the original code from the post, plt.colorbar() raised an error, so I left it out of the answer. If you have an example of how the colorbar was intended to work, we could look at how to make that happen for two figures, but the rest of the code should work as intended!
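For reference, a minimal sketch of how the question's colorbar line could translate to two figures with the object-oriented API. It assumes the colours come from a per-point array passed as c= (here a hypothetical `distance` array filled with random numbers, since the question does not show the real data):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical data: positions plus a per-point "distance from ocean" value
x1 = x2 = np.arange(10)
y1 = y2 = np.random.rand(10)
distance = np.random.rand(10)

fig1, axis1 = plt.subplots()
fig2, axis2 = plt.subplots()

# Pass c= so each scatter has values to map to colours
sc1 = axis1.scatter(x1, y1, c=distance)
sc2 = axis2.scatter(x2, y2, c=distance)

# OO counterpart of plt.colorbar().set_label(...): attach each
# colorbar to its own figure and axes explicitly
cb1 = fig1.colorbar(sc1, ax=axis1)
cb1.set_label('Distance from ocean', rotation=270)
cb2 = fig2.colorbar(sc2, ax=axis2)
cb2.set_label('Distance from ocean', rotation=270)

plt.close("all")
```

Inside the day loop, the same two colorbar lines would take the per-day axis (axis1/axis2) and the scatter it returned.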

Note that if a day ever does not appear in dayRowCol, NumPy will raise an IndexError (np.argwhere returns an empty array, so indexing it with [0] fails); it's up to you to decide how you want to handle that case. Also, using NumPy is definitely not the only way to do it, just one I'm comfortable with: all you really need is a way to link a given day/plot with the (row, column) indices of the axis you want to plot on.
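As an alternative to the argwhere lookup (not what the answer above uses, just another way to link a day to its axis), integer division gives the row and column directly for a row-major grid. The helper name `day_to_row_col` is hypothetical:

```python
import numpy as np

nRows = nCols = 6
dayRowCol = np.array([i + 1 for i in range(nRows * nCols)]).reshape(nRows, nCols)

def day_to_row_col(day, nCols=6):
    """Hypothetical helper: map a 1-based day to (row, col) in a row-major grid."""
    return divmod(day - 1, nCols)

# Check against the np.argwhere lookup for every plotted day
for day in range(1, 32):
    rowIdx, colIdx = np.argwhere(dayRowCol == day)[0]
    assert (rowIdx, colIdx) == day_to_row_col(day, nCols)

print(day_to_row_col(1), day_to_row_col(31))
```

Unlike argwhere, divmod never fails on its own; a day beyond the grid (e.g. 37 on a 6×6 grid gives row 6) only raises an IndexError when you then index axesArray1 with it.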

Eric
  • If you answer a question where there is already an answer, it would be good to make clear how yours differs. Having the same solution there twice is not useful. Also, the other answer shows how to use colorbar, so I don't think it makes sense to state that "we" could have a look at how it works; you yourself might have a look if you want. – ImportanceOfBeingErnest Oct 04 '17 at 22:19
  • Thanks for your reply. One question: where did axis1 and axis2 come from? There doesn't seem to be any connection to the axesArray1 that you introduced earlier (or is that just a mistake)? – Dr proctor Oct 05 '17 at 01:08
  • @ImportanceOfBeingErnest you're correct, the other answer appeared while I was writing up my answer and I didn't see it until after – Eric Oct 05 '17 at 14:53
  • @Drproctor `axis1` is created from `axesArray1` like so: `axis1 = axesArray1[rowIdx, colIdx]`. `axis2` is made in a similar way – Eric Oct 05 '17 at 14:54