I am using a machine learning algorithm (SVM) to classify a satellite image. The image is 7 GB, so I need the Python multiprocessing module to reduce the computation time. I have read the related posts on Stack Overflow as well as the documentation of the multiprocessing module. What I found was that my code became much slower after I started using the pool method. Obviously I am doing something wrong, but I have no idea what it is. Below is the code using multiprocessing:
img = '/home/SVM_seaIce_types/subset_of_subset_calibration_201605_Polar_stereographic_ratio.tif'

def predict():
    new_shape = (img.shape[0] * img.shape[1], img.shape[2] - 2)
    # reshape the image from 3d to 2d in order to use it for classification
    img_as_array = img[:, :, 2:].reshape(new_shape)
    print('Reshaped from {o} to {n}'.format(o=img.shape,
                                            n=img_as_array.shape))
    # Now predict for each pixel
    class_prediction = svm.predict(img_as_array)
    # Reshape the image and produce the classification map
    class_prediction = class_prediction.reshape(img[:, :, 0].shape)
    return class_prediction

if __name__ == '__main__':
    start = time.time()
    pool = Pool(processes=5)
    result = pool.apply_async(predict)
    print result.get()
    end = time.time()
    print 'the processing time is', (end - start)
Below is the code without multiprocessing:
img = '/home/SVM_seaIce_types/subset_of_subset_calibration_201605_Polar_stereographic_ratio.tif'
start = time.time()
new_shape = (img.shape[0] * img.shape[1], img.shape[2]-2 )
img_as_array = img[:, :, 2:].reshape(new_shape)
print('Reshaped from {o} to {n}'.format(o=img.shape,
                                        n=img_as_array.shape))
# Now predict for each pixel
class_prediction = svm.predict(img_as_array)
# Reshape our classification map
class_prediction = class_prediction.reshape(img[:, :, 0].shape)
print class_prediction
end = time.time()
print 'the processing time is:', end - start
Comparing the processing times, there is a big difference. I do not understand what is going on; I may not have understood how the multiprocessing module works, which is why I need your help.
By the way, the array filled with the number 1 corresponds to class 1 in my model's prediction.
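As far as I understand the documentation, a single `apply_async` call runs the submitted function once, in one worker process, so the other workers in the pool have nothing to do. A minimal sketch of splitting the pixel rows into chunks, so that each chunk could become its own task, is shown here. The names `split_rows`, `predict_chunks`, and `dummy_predict` are hypothetical, `dummy_predict` only stands in for `svm.predict`, and the built-in `map` is used in place of `pool.map` so that the sketch runs serially.

```python
import numpy as np

def split_rows(array_2d, chunk_size):
    """Split a (n_pixels, n_bands) array into row chunks of at most chunk_size."""
    return [array_2d[i:i + chunk_size]
            for i in range(0, array_2d.shape[0], chunk_size)]

def predict_chunks(predict_fn, array_2d, chunk_size, mapper=map):
    """Apply predict_fn to each chunk and join the 1-D results in order.

    mapper defaults to the built-in map (serial). Passing pool.map from a
    multiprocessing.Pool is the intended parallel use; predict_fn must then
    be a picklable, module-level function.
    """
    chunks = split_rows(array_2d, chunk_size)
    return np.concatenate(list(mapper(predict_fn, chunks)))

def dummy_predict(chunk):
    # Hypothetical stand-in for svm.predict: label 1 where the band sum is positive.
    return (chunk.sum(axis=1) > 0).astype(int)

# Made-up data: 8 pixels x 3 bands.
data = np.arange(-12, 12, dtype=float).reshape(8, 3)
labels = predict_chunks(dummy_predict, data, chunk_size=3)
```

With a real pool, the call would presumably look like `predict_chunks(dummy_predict, data, 3, mapper=pool.map)`, created inside the `if __name__ == '__main__':` guard.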
Edit, after getting some useful feedback:
Thank you everyone for your responses to my problem. I followed the comments you made and changed my code accordingly. My code now runs faster with the multiprocessing module used properly. The problem now is that I do not get the results I am expecting. Here is my code:
img = '/home/john/desktop/seaIce.tif'  # shape of image: 500 x 500 x 5 (rows, cols, bands)

# reshape image 3d (rows x columns x number of bands) into 2d (total size x number of bands)
image = (img.shape[0] * img.shape[1], img.shape[2] - 2)
tfs2d = img[:, :, 2:].reshape(image)

# find data without nan values
gpi = np.isfinite(tfs2d.sum(axis=0))
tfsgood = tfs2d[gpi, :]

# use svm model to classify our image
def predict_class(input_data):
    prediction = svm.predict(input_data)
    return prediction

def main():
    # slice the image into chunks
    chunk_size = 100
    chunks = [tfsgood[i:i+chunk_size, :]
              for i in xrange(0, tfsgood.shape[0], chunk_size)]

    # use multiprocessing module
    pool = Pool(6)
    svm_labelsgood = pool.map(predict_class, chunks)

    # join the results
    svm_labelsgood = np.dstack(svm_labelsgood)
    svmlabelsall = np.zeros(tfs2d.shape[0])
    svmlabelsall[gpi] = svm_labelsgood

    # reshape the image so we can display it with matplotlib
    reshape = svmlabelsall.reshape(img.shape[0], img.shape[1])
    print img.shape
    print reshape.shape
    plt.imshow(reshape)
    plt.show()

if __name__ == '__main__':
    main()
The classification image I get is this:
Instead of getting this:
The bottom image is the final classification image, consisting of 3 classes (sea ice types). The problem is that I cannot wait 10 hours to get the classification results. What am I doing wrong when applying the multiprocessing module? I have been staring at my code for hours and cannot understand why I get a blue image.
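A small shape check on synthetic data may help narrow this down. The sketch below is only an illustration under assumptions: the array sizes and label values are made up, and nothing here touches the real image or the SVM model. It compares what `axis=0` and `axis=1` produce for the finite-value mask, and what `np.dstack` and `np.concatenate` produce when joining per-chunk 1-D label arrays.

```python
import numpy as np

# Synthetic stand-in for the flattened image: 6 pixels x 3 bands (made-up sizes).
pixels = np.arange(18, dtype=float).reshape(6, 3)
pixels[2, 1] = np.nan  # one pixel with a missing band value

# axis=0 sums down the rows: one value per band.
mask_per_band = np.isfinite(pixels.sum(axis=0))
# axis=1 sums across the bands: one value per pixel.
mask_per_pixel = np.isfinite(pixels.sum(axis=1))

# Pretend each chunk of 2 pixels produced a 1-D array of labels.
chunk_labels = [np.array([1, 2]), np.array([3, 1]), np.array([2, 2])]

# dstack promotes each 1-D array to shape (1, N, 1) and stacks along axis 2.
stacked = np.dstack(chunk_labels)
# concatenate keeps everything 1-D, in chunk order.
joined = np.concatenate(chunk_labels)
```

Only `mask_per_pixel` has one entry per pixel row, and only `joined` is a flat array with one label per pixel, which is the shape a per-pixel assignment such as `svmlabelsall[mask] = labels` would expect.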