I have two classes, one called `algorithm` and the other called `Chain`. In `algorithm`, I create multiple chains, each of which is a sequence of sampled values. I want to run the sampling in parallel at the chain level.
In other words, the `algorithm` class instantiates `n` chains, and I want to run the `_sample` method, which belongs to the `Chain` class, for each of the chains in parallel within the `algorithm` class.
Below is sample code that attempts what I would like to do.
I have seen a similar question here: Apply a method to a list of objects in parallel using multi-processing, but as shown in the function `_sample_chains_parallel_worker`, that approach does not work for my case (I am guessing it is because of the nested class structure).
Question 1: Why does this not work for this case?
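One hypothesis worth testing for Question 1: `Pool.map` sends the function to the worker processes by pickling it, and `pickle` stores a function by reference to a name it can look up at module level. `worker` is defined inside `_sample_chains_parallel_worker`, so it has no such name and cannot be pickled; the nesting of the function, rather than of the classes, would then be the problem. A minimal check with the standard `pickle` module, where `make_worker`, `top_level_worker` and `can_pickle` are hypothetical names used only for illustration:

```python
import pickle


def make_worker():
    # Defined inside another function, like `worker` inside
    # `_sample_chains_parallel_worker`.
    def worker(obj):
        return obj
    return worker


def top_level_worker(obj):
    # Same body, but defined at module level.
    return obj


def can_pickle(obj):
    """Return True if `obj` survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # The exception type varies across Python versions.
        return False


print(can_pickle(top_level_worker))  # True
print(can_pickle(make_worker()))     # False
```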
The approach in `_sample_chains_parallel` does not even run in parallel.
Question 2: Why?
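A likely answer to Question 2: in `pool.apply_async(chain._sample())` the parentheses call `chain._sample()` immediately, in the parent process, and only its return value (`None`) is handed to `apply_async`. The chains are therefore sampled one after another inside the loop. A worker then tries to call `None`, and since the returned `AsyncResult` is never checked with `.get()`, that error is silently dropped. The sketch below (`whoami` and `run_demo` are hypothetical helpers) contrasts the buggy call with passing the function object itself; it is written to run on Python 2.7 and 3:

```python
import os
from multiprocessing import Pool


def whoami():
    """Return the PID of whichever process executes this function."""
    return os.getpid()


def run_demo():
    parent_pid = os.getpid()
    pool = Pool(processes=2)
    try:
        # Buggy form, like `pool.apply_async(chain._sample())`: whoami() runs
        # right here in the parent, and apply_async receives its return value
        # (an int), which a worker then fails to call.
        bad = pool.apply_async(whoami())
        try:
            bad.get(timeout=10)
            bad_error = None
        except TypeError as exc:
            bad_error = exc  # the worker's error, re-raised by .get()

        # Correct form: pass the function object; a worker process calls it.
        good_pid = pool.apply_async(whoami).get(timeout=10)
    finally:
        pool.close()
        pool.join()
    return parent_pid, bad_error, good_pid


if __name__ == "__main__":
    parent_pid, bad_error, good_pid = run_demo()
    print("error from buggy call: %r" % (bad_error,))
    print("correct call ran in a worker: %s" % (good_pid != parent_pid))
```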
Question 3: How do I sample each of these chains in parallel?
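For Question 3, one common pattern (a sketch, not the only option) combines two changes: a module-level worker function, so `Pool.map` can pickle it, and returning the sampled values to the parent. The second part matters because each worker operates on a pickled *copy* of its `Chain`, so appending to `self.thetas` inside the worker never changes the objects in the parent's `self.chains`. `sample_chain` is a hypothetical helper name; `Chain` and `algorithm` are trimmed versions of the classes in the question:

```python
import time
from multiprocessing import Pool


class Chain(object):
    def __init__(self):
        self.thetas = []

    def _sample(self):
        for i in range(3):
            time.sleep(1)
            self.thetas.append(i)

    def clear_thetas(self):
        self.thetas = []


def sample_chain(chain):
    """Hypothetical module-level worker: picklable because it has a
    top-level name. Runs in a worker process on a *copy* of `chain`."""
    chain.clear_thetas()
    chain._sample()
    return chain.thetas  # send the results back to the parent


class algorithm(object):
    def __init__(self, n=3):
        self.n = n
        self.chains = []

    def _init_chains(self):
        for _ in range(self.n):
            self.chains.append(Chain())

    def _sample_chains_parallel(self):
        pool = Pool(processes=self.n)
        try:
            # map blocks until every chain is done and keeps the input order.
            all_thetas = pool.map(sample_chain, self.chains)
        finally:
            pool.close()
            pool.join()
        # Copy the results onto the parent's Chain objects.
        for chain, thetas in zip(self.chains, all_thetas):
            chain.thetas = thetas


if __name__ == "__main__":
    alg = algorithm(n=3)
    alg._init_chains()
    start = time.time()
    alg._sample_chains_parallel()
    print("parallel %.1f s" % (time.time() - start))
    print([chain.thetas for chain in alg.chains])  # [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
```

Returning only `thetas` keeps what crosses the process boundary small; returning the whole `Chain` and reassigning `self.chains` should also work, provided `Chain` stays picklable.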
```python
import time
import multiprocessing

class Chain():
    def __init__(self):
        self.thetas = []

    def _sample(self):
        for i in range(3):
            time.sleep(1)
            self.thetas.append(i)

    def clear_thetas(self):
        self.thetas = []

class algorithm():
    def __init__(self, n=3):
        self.n = n
        self.chains = []

    def _init_chains(self):
        for _ in range(self.n):
            self.chains.append(Chain())

    def _sample_chains(self):
        for chain in self.chains:
            chain.clear_thetas()
            chain._sample()

    def _sample_chains_parallel(self):
        pool = multiprocessing.Pool(processes=self.n)
        for chain in self.chains:
            chain.clear_thetas()
            pool.apply_async(chain._sample())
        pool.close()
        pool.join()

    def _sample_chains_parallel_worker(self):
        def worker(obj):
            obj._sample()
        pool = multiprocessing.Pool(processes=self.n)
        pool.map(worker, self.chains)
        pool.close()
        pool.join()

if __name__=="__main__":
    import time
    alg = algorithm()
    alg._init_chains()

    start = time.time()
    alg._sample_chains()
    end = time.time()
    print "sequential", end - start

    start = time.time()
    alg._sample_chains_parallel()
    end = time.time()
    print "parallel", end - start

    start = time.time()
    alg._sample_chains_parallel_worker()
    end = time.time()
    print "parallel, map and worker", end - start
```