I have an image with a large number of bands and I want to perform some operation on each band, then output it to a new image. To speed things up, I want to use multiprocessing in a way that every process is working on one image band, but I get a pickling error (TypeError: cannot pickle 'SwigPyObject' object) and now I don't know if my idea is even possible.
Here is my code:
import os
import time
import itertools
import multiprocessing as mp

import numpy as np
from osgeo import gdal


def worker_function(src, band, dst):
    data = src.GetRasterBand(band).ReadAsArray()
    # do more here later
    print('\t ... wrrrrrrrrmmmmmm ...')
    dst.GetRasterBand(band).WriteArray(data)
    return


def wrapper(params):
    return worker_function(*params)


if __name__ == '__main__':
    start = time.time()
    cores = 3
    img_in = r'c:\Users\myname\input.tif'
    img_out = r'c:\Users\myname\output.tif'

    ### whole bunch of necessary raster creation things here ###
    if os.path.exists(img_out):
        os.remove(img_out)
    ds = gdal.Open(img_in, gdal.GA_ReadOnly)
    print(f'Raster Size: {ds.RasterXSize, ds.RasterYSize}')
    print(f'Number of bands: {ds.RasterCount}')
    print('Creating output file ...')
    driver = gdal.GetDriverByName('GTiff')
    out = driver.Create(img_out, ds.RasterXSize, ds.RasterYSize, ds.RasterCount, gdal.GDT_Byte)
    out.SetProjection(ds.GetProjection())
    out.SetGeoTransform(ds.GetGeoTransform())

    print(f'Setting up multiprocessing for {cores} cores ...')
    mp.freeze_support()
    print('Starting the pool ...')

    ### this is the important part ###
    with mp.Pool(processes=cores) as pool:
        pool.map(wrapper, zip(itertools.repeat(ds), [b for b in range(1, ds.RasterCount + 1)], itertools.repeat(out)))

    out = None
    ds = None
    print(f'\nDone! Took me {np.round(time.time() - start, 4)}s.')
Since ds and out are Swig Object of type 'GDALDatasetShadow *', I get that pickling error.
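One direction that avoids pickling the dataset handles altogether is to send each worker only the file path and a band number, let it open its own read-only handle, and do all writing in the parent process. The following is only a sketch under those assumptions: build_tasks, read_band and process_image are hypothetical names, `out` is assumed to be the output dataset already created as in the code above, and each band's array is assumed small enough to be sent back to the parent.

```python
import multiprocessing as mp


def build_tasks(img_path, band_count):
    # Only plain strings and ints cross the process boundary, so pickling works.
    return [(img_path, band) for band in range(1, band_count + 1)]


def read_band(task):
    # Imported inside the worker; assumes the osgeo package is installed.
    from osgeo import gdal

    img_path, band = task
    src = gdal.Open(img_path, gdal.GA_ReadOnly)  # handle owned by this process
    data = src.GetRasterBand(band).ReadAsArray()
    # per-band processing would go here
    src = None  # release the dataset
    return band, data


def process_image(img_in, out, band_count, cores=3):
    # `out` is the output dataset created in the parent; only the parent writes to it.
    with mp.Pool(processes=cores) as pool:
        tasks = build_tasks(img_in, band_count)
        for band, data in pool.imap_unordered(read_band, tasks):
            out.GetRasterBand(band).WriteArray(data)
```

Under these assumptions it would be called from inside the `if __name__ == '__main__':` block as process_image(img_in, out, ds.RasterCount, cores). The trade-off is that every processed band is pickled back to the parent, which costs memory and time for large rasters.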
Following the answer to "Python multiprocessing PicklingError: Can't pickle <type 'function'>", I tried replacing pool.map with pool.apply_async:
pool.apply_async(worker_function, zip(itertools.repeat(ds), [b for b in range(1, ds.RasterCount + 1)], itertools.repeat(out)))
But then I only get an empty result image and no prints, and the script finishes in less than half a second, which suggests to me that worker_function is never even entered. What am I missing here?
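For reference, apply_async schedules a single call and returns an AsyncResult immediately. Any exception from that call is stored on the result and only re-raised when .get() is called, and leaving the `with` block terminates the pool, so a failure can go completely unnoticed. The minimal sketch below illustrates this with ThreadPool as a stand-in (same apply_async API as mp.Pool, chosen here only so the snippet runs without spawning processes); failing_worker, double and demo are hypothetical names.

```python
from multiprocessing.pool import ThreadPool


def failing_worker(band):
    # Hypothetical stand-in for a task that cannot run.
    raise TypeError(f"band {band} failed")


def double(band):
    # Hypothetical stand-in for real per-band work.
    return band * 2


def demo():
    with ThreadPool(2) as pool:
        # apply_async schedules ONE call and returns immediately.
        result = pool.apply_async(failing_worker, (1,))
        try:
            result.get(timeout=10)  # the stored exception is re-raised here
            error = None
        except TypeError as exc:
            error = str(exc)

        # One apply_async call per band, then wait for every result.
        pending = [pool.apply_async(double, (b,)) for b in (1, 2, 3)]
        values = [p.get(timeout=10) for p in pending]
    return error, values
```

Without the .get() calls, the function would return as soon as the tasks were submitted, which matches the "finishes instantly, no prints" behaviour described above.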