6

I'm trying to create a scrollable multiplot based on the answer to this question: Creating a scrollable multiplot with python's pylab

Lines created using ax.plot() are updating correctly; however, I'm unable to figure out how to update the artists created using vlines() and fill_between().

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

# Build 100 random frames (keys 0..99), each a 30-row DataFrame with columns
# 'col1', 'col2' and 'col3'.
dfs = {}
for frame_no in range(100):
    # col1: noisy signal around 10
    noisy = np.random.normal(10, 0.5, 30)

    # col2: three consecutive runs of the levels 5, 8 and 7. The run lengths are
    # a random split of 31 samples, rounded, and the result is cut to 30 rows.
    shares = np.random.dirichlet(np.ones(3), size=1)[0]
    run_lengths = np.round(shares * 31).tolist()
    levels = np.repeat([5, 8, 7], run_lengths)[:30]

    # col3: random integers in 0..3
    heights = np.random.randint(4, size=30)

    dfs[frame_no] = pd.DataFrame({'col1': noisy, 'col2': levels, 'col3': heights})

# Figure with a single data axes laid out through a 1x1 GridSpec; y-range 0..12
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

# Slider (0..99) on its own small axes, used to select the frame
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(ax=axframe, label='frame', valmin=0, valmax=99, valinit=0, valfmt='%d')

# Line plots for frame 0; the Line2D handles are kept for later updates
first = dfs[0]
ln1, = ax.plot(first.index, first['col1'])
ln2, = ax.plot(first.index, first['col2'], c='black')

# Artists for frame 0: one translucent band (col2 +/- 0.5) per col2 level,
# drawn only where col2 equals that level, plus vertical lines of height col3
for level, colour in ((5, 'r'), (8, 'b'), (7, 'g')):
    ax.fill_between(
        first.index,
        y1=first['col2'] - 0.5,
        y2=first['col2'] + 0.5,
        where=first['col2'] == level,
        facecolor=colour,
        edgecolors='none',
        alpha=0.5,
    )
ax.vlines(x=first['col3'].index, ymin=0, ymax=first['col3'], color='black')

#update plots
def update(val):
    """Slider callback: show the frame currently selected on the slider.

    ``val`` is not used; the position is read from ``sframe.val`` and rounded
    down to a whole frame number. Only the two line plots and the title are
    refreshed here.
    """
    idx = int(np.floor(sframe.val))
    current = dfs[idx]
    for line, column in ((ln1, 'col1'), (ln2, 'col2')):
        line.set_ydata(current[column])
    ax.set_title('Frame ' + str(idx))
    plt.draw()

# Register update() as the slider's change callback, then display the figure.
sframe.on_changed(update)
plt.show()

This is what it looks like at the moment: [image]

I can't apply the same approach as for plot(), since the following produces an error message:

ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable

This is what it's meant to look like on each frame: [image]

Paulie-C
  • 1,674
  • 1
  • 13
  • 29
themachinist
  • 1,413
  • 2
  • 17
  • 22

3 Answers

2

fill_between returns a PolyCollection, which expects a list (or several lists) of vertices upon creation. Unfortunately, I haven't found a way to retrieve the vertices that were used to create a given PolyCollection, but in your case it is easy enough to create the PolyCollection directly (thereby avoiding the use of fill_between) and then update its vertices when the frame changes.

Below a version of your code that does what you are after:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

from matplotlib.collections import PolyCollection

# Build 100 random frames (keys 0..99), each a 30-row DataFrame with columns
# 'col1', 'col2' and 'col3'.
dfs = {}
for frame_no in range(100):
    # col1: noisy signal around 10
    noisy = np.random.normal(10, 0.5, 30)

    # col2: three consecutive runs of the levels 5, 8 and 7. The run lengths are
    # a random split of 31 samples, rounded, and the result is cut to 30 rows.
    shares = np.random.dirichlet(np.ones(3), size=1)[0]
    run_lengths = np.round(shares * 31).tolist()
    levels = np.repeat([5, 8, 7], run_lengths)[:30]

    # col3: random integers in 0..3
    heights = np.random.randint(4, size=30)

    dfs[frame_no] = pd.DataFrame({'col1': noisy, 'col2': levels, 'col3': heights})

# Figure with a single data axes laid out through a 1x1 GridSpec; y-range 0..12
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

# Slider (0..99) on its own small axes, used to select the frame
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(ax=axframe, label='frame', valmin=0, valmax=99, valinit=0, valfmt='%d')

# Line plots for frame 0; the Line2D handles are kept for later updates
first = dfs[0]
ln1, = ax.plot(first.index, first['col1'])
ln2, = ax.plot(first.index, first['col2'], c='black')

## col2 levels highlighted by the red, blue and green PolyCollections
val_r, val_b, val_g = 5, 8, 7

def update_collection(collection, value, frame = 0):
    """Reshape *collection* into the rectangle marking where ``col2 == value``.

    Parameters
    ----------
    collection
        Object with a ``set_verts`` method (a PolyCollection in this script);
        its vertices are replaced in place.
    value
        The ``col2`` level to highlight. The rectangle spans
        ``value - 0.5`` to ``value + 0.5`` vertically.
    frame
        Key into the module-level ``dfs`` dict selecting the DataFrame to use.

    The rectangle runs from the smallest to the largest index at which
    ``col2`` equals ``value``. When no row matches, the collection receives a
    single empty polygon.
    """
    data = dfs[frame]
    xs = np.asarray(data.index)
    # Build the mask once and test it explicitly instead of relying on
    # np.min/np.max raising ValueError on an empty selection.
    hits = xs[np.asarray(data['col2']) == value]

    if hits.size:
        left, right = hits.min(), hits.max()
        bottom, top = value - 0.5, value + 0.5
        verts = np.array([[left, bottom], [right, bottom], [right, top], [left, top]])
    else:
        verts = np.zeros((0, 2))

    collection.set_verts([verts])

#artists

## Each fill_between call is replaced by an initially empty PolyCollection that
## update_collection() shapes for frame 0 here and reshapes on every slider move.

##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
reds = PolyCollection([], facecolors=['r'], alpha=0.5)
ax.add_collection(reds)
update_collection(reds, val_r)

##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==8,facecolor='b',edgecolors='none',alpha=0.5)
blues = PolyCollection([], facecolors=['b'], alpha=0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==7,facecolor='g',edgecolors='none',alpha=0.5)
greens = PolyCollection([], facecolors=['g'], alpha=0.5)
ax.add_collection(greens)
update_collection(greens, val_g)


def vline_segments(frame = 0):
    """Return one ``[(x, 0), (x, col3)]`` segment per row of ``dfs[frame]``."""
    col3 = dfs[frame]['col3']
    return [[(x, 0), (x, top)] for x, top in zip(col3.index, col3)]


## vlines() returns a LineCollection; keep it so update() can replace its
## segments instead of leaving the lines frozen at frame 0.
vline_artist = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')

#update plots
def update(val):
    """Slider callback: redraw every artist for the frame selected on the slider.

    ``val`` is not used; the position is read from ``sframe.val`` and rounded
    down to a whole frame number. The two line plots, the title, the three
    PolyCollections and the vertical lines are all refreshed.
    """
    frame = int(np.floor(sframe.val))
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))

    ##updating the PolyCollections:
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)

    ##updating the vertical lines:
    vline_artist.set_segments(vline_segments(frame))

    plt.draw()

#connect callback to slider
sframe.on_changed(update)
plt.show()

Each of the three PolyCollections (reds, blues, and greens) has only four vertices (the corners of a rectangle), which are determined from the given data in update_collection. The result looks like this:

example result of given code

Tested in Python 3.5

Thomas Kühn
  • 9,412
  • 3
  • 47
  • 63
0

Your error

TypeError: 'PolyCollection' object is not iterable

can be avoided by removing the comma after ln3 (called l3 below):

l3 = ax.fill_between(xx, y1, y2, **kwargs)

The return value is a PolyCollection; you need to update its vertices inside the update() function. An alternative to the other answer posted here is to have fill_between() give you a new PolyCollection, then take its vertices and use them to update those of l3:

def update(val):
    """Slider callback that refreshes the filled region ``l3`` in place.

    A throwaway PolyCollection is produced by ``fill_between`` and its path
    data is copied into ``l3``. ``xx``, ``y1``, ``y2`` and ``kwargs`` are
    placeholders defined outside this snippet; presumably they hold the data
    for the newly selected frame.
    """
    # Let fill_between compute the polygons, then read them back
    dummy_l3 = ax.fill_between(xx, y1, y2, **kwargs)
    paths = dummy_l3.get_paths()
    # Use Path's public attributes rather than the private _vertices/_codes
    verts = [path.vertices for path in paths]
    codes = [path.codes for path in paths]
    # Take the temporary artist off the axes again
    dummy_l3.remove()
    l3.set_verts_and_codes(verts, codes)
    plt.draw()
jasondet
  • 93
  • 1
  • 5
0

The above code does not run for me; however, to refresh fill_between the following works for me

    %matplotlib inline
    
    import numpy as np
    from IPython import display
    import matplotlib.pyplot as plt
    import time
    
    # Create a display handle; its update() method is called below each time
    # the figure should be re-rendered in the notebook.
    hdisplay   = display.display("", display_id=True)
    
    # Initial figure: a dotted red line through 100 random points with a
    # translucent band of half-width dy around it.
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    x = np.linspace(0,1,100)
    ax.set_title("Test")
    ax.set_xlim()
    y = np.random.random(size=(100))
    dy = 0.1
    # ax.plot returns a list; the line itself is l[0] (used in the loop below)
    l = ax.plot(x,y,":",color="red")
    # Keep the fill_between artist so it can be removed and recreated later
    b = ax.fill_between( x, y-dy, y+dy, color="red", alpha=0.2 )
    hdisplay.update(fig)  
    
    # Five refreshes, one second apart, each with new random y data
    for i in range(5):
        time.sleep(1)
        ax.set_title("Test %ld" % i)
        y = np.random.random(size=(100))
        # The line is updated in place ...
        l[0].set_ydata( y )
        # ... while the band is removed and drawn again from the new data
        b.remove()
        b = ax.fill_between( x, y-dy, y+dy, color="red", alpha=0.2 )
        plt.draw()
        hdisplay.update(fig)  
    
    # NOTE(review): presumably closes the figure so the inline backend does
    # not render it a second time at the end of the cell - confirm.
    plt.close(fig) 
Probably
  • 1
  • 1