I am using the following code to create a collection of color-coded line plots:
    for j in idlist[i]:
        single_traj(lonarray, latarray, parray)
    plt.savefig(savename, dpi = 400)
    plt.close('all')
    plt.clf()
where:
    def single_traj(lonarray, latarray, parray, linewidth = 0.7):
        """
        Plots XY Plot of one trajectory, with color as a function of p
        Helper Function for DrawXYTraj
        """
        global lc
        x = lonarray
        y = latarray
        p = parray
        points = np.array([x,y]).T.reshape(-1,1,2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = col.LineCollection(segments, cmap=plt.get_cmap('Spectral'),
                                norm=plt.Normalize(100, 1000), alpha = 0.8)
        lc.set_array(p)
        lc.set_linewidth(linewidth)
        plt.gca().add_collection(lc)
Somehow, this loop uses a lot of memory (more than 10 GB), which is still in use after the plot is saved. I used hpy (from guppy) to look at memory usage:
    Partition of a set of 27472988 objects. Total size = 10990671168 bytes.
     Index    Count   %        Size   %   Cumulative   % Kind (class / dict of class)
         0  8803917  32  9226505016  84   9226505016  84 dict of matplotlib.path.Path
         1  8888542  32   711083360   6   9937588376  90 numpy.ndarray
         2  8803917  32   563450688   5  10501039064  96 matplotlib.path.Path
         3       11   0   219679112   2  10720718176  98 guppy.sets.setsc.ImmNodeSet
         4    25407   0    77593848   1  10798312024  98 list
         5    89367   0    28232616   0  10826544640  99 dict (no owner)
         6     7642   0    25615984   0  10852160624  99 dict of matplotlib.collections.LineCollection
         7    15343   0    16079464   0  10868240088  99 dict of matplotlib.transforms.CompositeGenericTransform
         8    15327   0    16062696   0  10884302784  99 dict of matplotlib.transforms.Bbox
         9    53741   0    15047480   0  10899350264  99 dict of weakref.WeakValueDictionary
At this point the plot has already been saved, so all matplotlib-related objects should be gone. But I can't find these objects, which means I don't know how to delete them.
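A minimal sketch of how such surviving objects might be located, using only the standard library's gc module. The helper names count_live_objects and find_instances are hypothetical and not part of the code above. Note that gc.get_objects() only lists objects tracked by the garbage collector, so this is a rough view rather than a full heap profile like the hpy output.

```python
import gc
from collections import Counter


def count_live_objects(top=10):
    """Return the `top` most common type names among gc-tracked objects."""
    # Collect first so objects kept alive only by reference cycles are dropped.
    gc.collect()
    counts = Counter(type(obj).__name__ for obj in gc.get_objects())
    return counts.most_common(top)


def find_instances(type_name):
    """Return all gc-tracked objects whose type is named `type_name`."""
    gc.collect()
    return [obj for obj in gc.get_objects()
            if type(obj).__name__ == type_name]
```

Calling find_instances('Path') after the plotting loop would return any Path objects that are still alive, and gc.get_referrers(obj) on one of them lists the objects that still hold a reference to it, which may help in tracing where the references come from.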
EDIT:
Here is a stand-alone example that reproduces the leak (savefig throws an error for some reason, so it is commented out; it is not relevant to the leak anyway):
    # Memory leak test!
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.collections as col

    def draw():
        x = range(1000)
        y = range(1000)
        p = range(1000)
        fig = plt.figure(figsize = (12,8))
        ax = plt.gca()
        ax.set_aspect('equal')
        # Add the same color-coded trajectory 1000 times
        for i in range(1000):
            if i%100 == 0:
                print(i)
            # Pair consecutive (x, y) points into line segments
            points = np.array([x,y]).T.reshape(-1,1,2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = col.LineCollection(segments, cmap=plt.get_cmap('Spectral'),
                                    norm=plt.Normalize(0, 1000), alpha = 0.8)
            lc.set_array(p)
            lc.set_linewidth(0.7)
            plt.gca().add_collection(lc)
        cb = fig.colorbar(lc, shrink = 0.7)
        cb.set_label('p')
        cb.ax.invert_yaxis()
        plt.tight_layout()
        #plt.savefig('./mem_test.png', dpi = 400)
        plt.close('all')
        plt.clf()

    draw()
    # Keep the process alive so its memory usage can be inspected
    a = input('Wait...')
The draw() function should delete all plt objects, but they still use up memory after the function returns. I checked this with top/htop.
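As a complement to watching top/htop, here is a minimal sketch of measuring the retained memory from inside the process. It assumes Python 3.4 or later, where the standard library's tracemalloc module is available; the helper name traced_growth is hypothetical. It only sees allocations that go through Python's memory allocators, so its numbers will not necessarily match the resident size that top/htop report.

```python
import gc
import tracemalloc


def traced_growth(func, *args, **kwargs):
    """Call func and return (bytes still allocated afterwards, peak bytes).

    Both numbers are relative to the moment tracing starts, and the return
    value of func is discarded so it cannot keep memory alive.
    """
    gc.collect()
    tracemalloc.start()
    before, _ = tracemalloc.get_traced_memory()
    func(*args, **kwargs)
    # Collect again so that only memory which is really still referenced counts.
    gc.collect()
    after, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return after - before, peak
```

With the example above, growth, peak = traced_growth(draw) would show how many traced bytes are still held after draw() has returned, which separates "memory is still referenced by Python objects" from "memory was freed but not returned to the operating system".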