I am rewriting some code from pure Python to JAX. In my old code, I used Python's multiprocessing module to parallelize the evaluation of a function over all of the CPU cores in a single node, as follows:
import multiprocessing

# start the pool of worker processes
pool = multiprocessing.Pool(processes=10)  # if the node has 10 CPU cores, start 10 processes
# use pool.map to evaluate function(input) for each input in parallel
# (len(inputs) is very large; 10 inputs are processed in parallel at a time)
# and store the results in a list called out
out = pool.map(function, inputs)
# close the pool processes to free memory
pool.close()
pool.join()
I know that JAX has vmap and pmap, but I don't understand whether either of those is a drop-in replacement for how I'm using multiprocessing.Pool.map above.
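To make the comparison concrete, here is a minimal sketch of the vmap version on a toy problem. The names `function` and `inputs` are hypothetical stand-ins for my real function and data, and the shapes are chosen arbitrarily.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for my real function: maps one input vector to a scalar.
def function(x):
    return jnp.sum(x ** 2)


# Hypothetical stand-in for my real inputs: 1000 inputs of length 3,
# stacked along axis 0.
inputs = jnp.arange(3000.0).reshape(1000, 3)

# vmap turns the per-input function into a single batched computation
# that maps over axis 0 of inputs.
out = jax.vmap(function, in_axes=0)(inputs)

print(out.shape)      # one result per input
print(jax.devices())  # the devices JAX can see on this machine
```

As far as I understand, this produces one batched computation rather than 1000 separate calls, which is why I am unsure whether it makes use of more than one core.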
- Is
vmap(function,in_axes=0)(inputs)
distributing the work across all available CPU cores?
- How is
pmap(function,in_axes=0)(inputs)
different from vmap and multiprocessing.Pool.map?
- Is my usage of multiprocessing.Pool.map above an example of the "single-program, multiple-data (SPMD)" code that pmap is meant for?
- When I actually run
pmap(function,in_axes=0)(inputs)
I get an error:
ValueError: compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1)
What does this mean?
- Finally, my use case is very simple: I merely want to use some or all of the CPU cores on a single node (e.g., all 10 CPU cores on my MacBook). But I have heard about nesting pmap(vmap). Is this used to parallelize over the cores of multiple connected nodes (say, on a supercomputer)? That would be more akin to mpi4py than to multiprocessing (the latter is restricted to a single node).
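For what it's worth, here is a sketch of the workaround I have seen suggested for the pmap error, under a few assumptions. The XLA flag `--xla_force_host_platform_device_count` asks the CPU backend to present itself as several devices, and it has to be set before jax is first imported. The names `function` and `inputs` are again hypothetical stand-ins, and the reshape assumes the number of inputs divides evenly by the device count. I have not checked how this performs compared with multiprocessing.

```python
import os

# Assumption: this must run before the first `import jax`. The flag makes the
# CPU backend report 10 devices. setdefault leaves any XLA_FLAGS value that is
# already set in the environment untouched.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=10")

import jax
import jax.numpy as jnp


# Hypothetical stand-in for my real function: maps one input vector to a scalar.
def function(x):
    return jnp.sum(x ** 2)


# pmap maps the leading axis over devices, so that axis cannot be larger
# than the number of devices JAX can see.
n_dev = jax.local_device_count()
print("devices visible to JAX:", n_dev)

# Hypothetical inputs: 100 inputs per device, each of length 3, arranged as
# (n_dev, inputs_per_device, 3) so the leading axis matches the device count.
n_inputs = 100 * n_dev
inputs = jnp.arange(n_inputs * 3.0).reshape(n_dev, -1, 3)

# pmap splits the leading axis across devices; vmap batches over the
# second axis within each device.
out = jax.pmap(jax.vmap(function))(inputs)

# Flatten back to one result per input.
out = out.reshape(-1)
print(out.shape)
```

In this sketch the nested pmap(vmap) still runs on a single node: pmap handles the per-device split and vmap handles the batch on each device. Whether the same nesting is what people use across multiple nodes is exactly what I am unsure about.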