I'm trying to use a cache shared by multiple processes, using a multiprocessing.Manager dict. The following demo gives some context (adapted from this answer):

import multiprocessing as mp
import time

def foo_pool(x, cache):
    if x not in cache:
        time.sleep(2)  # simulate an expensive computation
        cache[x] = x*x
    else:
        print('using cache for', x)
    return cache[x]

result_list = []
def log_result(result):
    result_list.append(result)

def apply_async_with_callback():
    manager = mp.Manager()
    cache = manager.dict()
    pool = mp.Pool()
    jobs = list(range(10)) + list(range(10))
    for i in jobs:
        pool.apply_async(foo_pool, args = (i, cache), callback = log_result)
    pool.close()
    pool.join()
    print(result_list)

if __name__ == '__main__':
    apply_async_with_callback()

Running the above code gives something like this:

using cache for 0
using cache for 2
using cache for 4
using cache for 1
using cache for 3
using cache for 5
using cache for 7
using cache for 6
[25, 16, 4, 1, 9, 0, 36, 49, 0, 4, 16, 1, 9, 25, 49, 36, 64, 81, 81, 64]

So the cache is working as expected.

What I'd like to achieve is to give a size limit to this manager.dict(), like the maxsize argument of functools.lru_cache. My current attempt is:

class LimitedSizeDict:
    def __init__(self, max_size):
        self.max_size = max_size
        self.manager = mp.Manager()
        self.dict = self.manager.dict()
        self.keys = self.manager.list()

    def __getitem__(self, key):
        return self.dict[key]

    def __setitem__(self, key, value):
        if len(self.keys) >= self.max_size:
            oldest_key = self.keys.pop(0)
            del self.dict[oldest_key]
        self.keys.append(key)
        self.dict[key] = value

    def __contains__(self, key):
        return key in self.dict

    def __len__(self):
        return len(self.dict)

    def __iter__(self):
        for key in self.keys:
            yield key

Then I use the following to launch the processes:

def apply_async_with_callback():
    cache = LimitedSizeDict(3)
    pool = mp.Pool()
    jobs = list(range(10)) + list(range(10))
    for i in jobs:
        pool.apply_async(foo_pool, args = (i, cache), callback = log_result)
    pool.close()
    pool.join()
    print(result_list)

But this gives me an empty list: [].
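I suspect the tasks are failing silently, since apply_async only reports worker errors on request. A quick way to surface the exception, using apply_async's standard error_callback parameter:

def log_error(e):
    print('task failed:', e)

# Same submission loop as above, but failures are printed instead of swallowed:
pool.apply_async(foo_pool, args=(i, cache),
                 callback=log_result, error_callback=log_error)

With this in place, the underlying error (presumably a pickling failure from passing the Manager-holding LimitedSizeDict to the pool) gets printed instead of disappearing.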

I thought I would probably have to subclass the multiprocessing.managers.DictProxy class to achieve this, so I looked into the source code. But there doesn't seem to be a class definition of DictProxy anywhere.
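It turns out DictProxy is generated at runtime via MakeProxyType in multiprocessing/managers.py, which is why no class definition appears. Abbreviated from the CPython source, it looks roughly like this:

from multiprocessing.managers import MakeProxyType

DictProxy = MakeProxyType('DictProxy', (
    '__contains__', '__delitem__', '__getitem__', '__len__',
    '__setitem__', 'clear', 'copy', 'get', 'items',
    'keys', 'pop', 'popitem', 'setdefault', 'update', 'values',
    ))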

How can I give a size limit to this shared dict cache? Thanks in advance.

Jason

1 Answer

First of all, I would define LimitedSizeDict so that it is not coupled to multiprocessing and can be used as a standalone class; it should therefore hold no references to a "manager" or managed objects. Second, I would give the class its own iterator class, because your current implementation relies on a generator, and generators cannot be pickled across processes. Third, there is a way of generating a proxy for just about any arbitrary class, as in the following code:

from multiprocessing import Process
from multiprocessing.managers import NamespaceProxy, BaseManager
import inspect
from collections import deque
from threading import Lock

class LimitedSizeDict:
    class Iter:
        def __init__(self, cache):
            self._cache = cache
            self._index = 0

        def __next__(self):
            if self._index >= len(self._cache):
                raise StopIteration
            key = self._cache._get_key(self._index)
            self._index += 1
            return key

    def __init__(self, max_size):
        self._max_size = max_size
        self._d = {}
        self._keys = deque(maxlen=max_size)
        # When not being used with multiprocessing:
        self._proxy = self
        self._lock = Lock()

    def __len__(self):
        return len(self._keys)

    def __getitem__(self, key):
        return self._d[key]

    def __setitem__(self, key, value):
        with self._lock:
            # The key may already exist, in which case nothing is evicted:
            if key not in self._d:
                if len(self._keys) == self._max_size:
                    oldest_key = self._keys[0]
                    del self._d[oldest_key]
                # Appending to a full deque automatically drops self._keys[0]:
                self._keys.append(key)
            self._d[key] = value

    # Required by iterator:
    def _get_key(self, index):
        return self._keys[index]

    def __iter__(self):
        return LimitedSizeDict.Iter(self._proxy)

    # When used in multiprocessing
    def _set_proxy(self, proxy):
        self._proxy = proxy

def worker(cache):
    cache['a'] = 1
    cache['b'] = 2
    cache['c'] = 3
    cache['d'] = 4
    for key in cache:
        print(key, cache[key])


class ObjProxy(NamespaceProxy):
    """Returns a proxy instance for any user defined data-type. The proxy instance will have the namespace and
    functions of the data-type (except private/protected callables/attributes). Furthermore, the proxy will be
    pickable and can its state can be shared among different processes. """

    @classmethod
    def populate_obj_attributes(cls, real_cls):
        DISALLOWED = set(dir(cls))
        DISALLOWED.add('__class__')
        ALLOWED = ['__sizeof__', '__eq__', '__ne__', '__le__', '__repr__', '__dict__', '__lt__',
                   '__gt__']
        new_dict = {}
        for (attr, value) in inspect.getmembers(real_cls, callable):
            if attr not in DISALLOWED or attr in ALLOWED:
                new_dict[attr] = cls.proxy_wrap(attr)
        return new_dict

    @staticmethod
    def proxy_wrap(attr):
        """ This method creates function that calls the proxified object's method."""
        def f(self, *args, **kwargs):

            # _callmethod is the method that proxies provided by multiprocessing use to call methods in the proxified object
            return self._callmethod(attr, args, kwargs)

        return f


# Create the proxy class at runtime:
LimitedSizeDictProxy = type("LimitedSizeDictProxy", (ObjProxy,), ObjProxy.populate_obj_attributes(LimitedSizeDict))


if __name__ == '__main__':
    BaseManager.register('LimitedSizeDict', LimitedSizeDict, LimitedSizeDictProxy, exposed=tuple(dir(LimitedSizeDictProxy)))
    with BaseManager() as manager:
        cache = manager.LimitedSizeDict(3)
        # Store the proxy in the actual object:
        cache._set_proxy(cache)
        p = Process(target=worker, args=(cache,))
        p.start()
        p.join()

Prints:

b 2
c 3
d 4
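Since LimitedSizeDict is not coupled to multiprocessing, it also works stand-alone in a single process. A small usage sketch:

cache = LimitedSizeDict(2)
cache['a'] = 1
cache['b'] = 2
cache['c'] = 3        # evicts 'a', the oldest key
print(list(cache))    # ['b', 'c']
print(cache['c'])     # 3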
Booboo