My goal is to produce multiple png files from respectively multiple numpy arrays, loaded from medical images in my HD. To make things quicker, I'm using dask delayed. Here's my working code:
import os.path
from glob import glob
import nibabel as nib
import numpy as np
from dask import delayed
def process(data):
# Need to have the import inside so that multiprocessing works.
# Apparently doesn't solve the issue anyway..
import matplotlib.pyplot as plt
outpath = '/Users/user/outputdir/'
name = os.path.basename(data.get_filename())
savename = name[:name.index('.')] + '.png'
plt.imshow(np.rot90(data.get_data()[15:74, 6:82, 18, 0]),
extent=[0, 1, 0, 1], aspect=1.28, cmap='gray')
plt.axis('off')
out = os.path.join(outpath, savename)
plt.savefig(out)
plt.close()
return out
L = []
for fn in glob("/Users/user/imagefiles/mb*.nii.gz"):
nifti = delayed(nib.load)(fn)
outpng = delayed(process)(nifti)
L.append(outpng)
results = delayed(print)(L)
results.compute()
My problem is that after each run some of the output images are empty (nothing in the png), and which images are empty seem pretty random, since all input data is valid.
I suspect this is a problem of multiprocessing and matplotlib, as seen in other related thread.
Does anyone have a suggestion on how to get this working with dask
?
EDIT: Minimal working example
import os.path
import random
import string
import numpy as np
from dask import delayed
def gendata(fn):
return
def process(data):
# Need to have the import inside so that multiprocessing works.
import matplotlib.pyplot as plt
outpath = '/Users/user/Pictures/test/'
name = ''.join(random.choices(string.ascii_lowercase, k=10))
savename = name + '.png'
data = np.random.randint(0, 255, size=(100,100,20,2))
plt.imshow(np.rot90(data[15:74, 6:82, 18, 0]),
extent=[0, 1, 0, 1], aspect=1.28, cmap='gray')
plt.axis('off')
out = os.path.join(outpath, savename)
plt.savefig(out)
plt.close()
return out
L = []
for fn in range(0, 10):
nifti = delayed(gendata)(fn)
outpng = delayed(process)(nifti)
L.append(outpng)
results = delayed(print)(L)
results.compute()