I'm trying to speed up a multivariate fixed-point iteration algorithm using multiprocessing, but I'm running into issues with shared data. My solution vector is actually a dictionary keyed by state rather than a plain vector of numbers, and each element is computed using a different formula. At a high level, the algorithm looks like this:

current_estimate = dict(previous_estimate)  # separate copy, not an alias
while True:
    for state in all_states:
        current_estimate[state] = state.getValue(previous_estimate)
    if norm(current_estimate, previous_estimate) < tolerance:
        break
    else:
        previous_estimate, current_estimate = current_estimate, previous_estimate
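
For reference, here is a minimal runnable sketch of the sequential loop above. The State class, its contraction rule (x -> 0.5*x + val), the max-norm, and the max_iter cap are all assumptions made for illustration; they are not the real formulas from my problem.

```python
class State(object):
    """Hypothetical state whose update rule is a contraction: x -> 0.5*x + val."""
    def __init__(self, val):
        self.val = val

    def getValue(self, est):
        return 0.5 * est[self] + self.val


def norm(a, b):
    # Max-norm distance between two estimate dictionaries with the same keys
    return max(abs(a[k] - b[k]) for k in a)


def fixed_point(all_states, tolerance=1e-9, max_iter=1000):
    previous_estimate = {s: 0.0 for s in all_states}
    current_estimate = dict(previous_estimate)  # separate dict, not an alias
    for _ in range(max_iter):
        # Every entry is recomputed from the read-only previous estimate
        for state in all_states:
            current_estimate[state] = state.getValue(previous_estimate)
        if norm(current_estimate, previous_estimate) < tolerance:
            break
        # Swap buffers; the stale one is fully overwritten on the next pass
        previous_estimate, current_estimate = current_estimate, previous_estimate
    return current_estimate


states = [State(i) for i in range(5)]
solution = fixed_point(states)
print([round(solution[s], 6) for s in states])
```

With this assumed update rule, each entry should settle near 2 * val, which makes it easy to check that a parallel version gives the same answer.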

I'm trying to parallelize the for-loop part with multiprocessing. The previous_estimate variable is read-only and each process only needs to write to one element of current_estimate. My current attempt at rewriting the for-loop is as follows:

from multiprocessing import Manager, Pool
import itertools

# Class and function definitions
class A(object):
    def __init__(self,val):
        self.val = val

    # representative getValue function
    def getValue(self, est):
        return est[self] + self.val

def worker(state, in_est, out_est):
    out_est[state] = state.getValue(in_est)

def worker_star(a_b_c):
    """ Allow multiple arguments for a pool
        Taken from http://stackoverflow.com/a/5443941/3865495
    """
    return worker(*a_b_c)

# Initialize test environment
manager = Manager()
estimates = manager.dict()
all_states = []
for i in range(5):
    a = A(i)
    all_states.append(a)
    estimates[a] = 0

pool = Pool(processes=2)
prev_est = estimates
curr_est = estimates
pool.map(worker_star, itertools.izip(all_states, itertools.repeat(prev_est), itertools.repeat(curr_est)))

The issue I'm currently running into is that the elements added to the all_states list are not the same as those added to the manager.dict(). I keep getting KeyError exceptions when trying to access elements of the dictionary using elements of the list. While debugging, I found that none of the elements share an id:

print map(id, estimates.keys())
>>> [19558864, 19558928, 19558992, 19559056, 19559120]
print map(id, all_states)
>>> [19416144, 19416208, 19416272, 19416336, 19416400]
CoconutBandit
1 Answer

This is happening because the objects you put into the estimates DictProxy aren't the same objects as those in your all_states list. The manager.dict() call returns a DictProxy, which proxies access to a dict that lives in a completely separate manager process. When you insert keys into it, they are pickled and sent to that process, and estimates.keys() hands you back unpickled copies, so they have a different identity. Because A doesn't define __eq__ or __hash__, Python falls back to identity-based comparison and hashing, so the copies never match the originals.
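
You can see the effect without a manager at all. The sketch below assumes a plain pickle round trip is a fair stand-in for what happens when a key travels to the manager process and back; the variable names are just for illustration.

```python
import pickle

class A(object):
    def __init__(self, val):
        self.val = val

original = A(3)
# Simulate the trip through the manager process: serialize, then deserialize
copy = pickle.loads(pickle.dumps(original))

same_object = copy is original         # the round trip builds a new instance
lookup_works = copy in {original: 0}   # default hash/eq are identity-based
print(same_object, lookup_works)
```

The copy carries the same val, but as far as a dict is concerned it is an unrelated key.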

To work around this, you can define your own __eq__ and __hash__ functions on A, as described in this question:

class A(object):
    def __init__(self,val):
        self.val = val

    # representative getValue function
    def getValue(self, est):
        return est[self] + self.val

    def __hash__(self):
        return hash(self.__key())

    def __key(self):
        return (self.val,)

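    def __eq__(self, other):
        # Only instances of A with the same key compare equal
        return isinstance(other, A) and self.__key() == other.__key()

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__
        return not self == other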

This means key lookups in estimates will use the value of the val attribute for hashing and equality, rather than the id assigned by Python. Note that val should not change while an object is in use as a key, since that would change its hash.
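
As a quick check, here is a sketch that again uses a pickle round trip as an assumed stand-in for the manager process. The keys in estimates are copies, yet the original objects in all_states can still look them up.

```python
import pickle

class A(object):
    def __init__(self, val):
        self.val = val

    def getValue(self, est):
        return est[self] + self.val

    def __key(self):
        return (self.val,)

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        return isinstance(other, A) and self.__key() == other.__key()

    def __ne__(self, other):
        return not self == other

all_states = [A(i) for i in range(5)]
# Keys are pickled copies, standing in for what a manager.dict() would hold
estimates = {pickle.loads(pickle.dumps(a)): 0 for a in all_states}

# Lookups with the original objects now succeed despite different ids
new_values = [a.getValue(estimates) for a in all_states]
print(new_values)
```

The same lookups raise KeyError if you delete the __hash__ and __eq__ methods.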

dano