
I want to store a dict with many numpy arrays and share it across processes.

import ctypes
import multiprocessing
from typing import Dict, Any

import numpy as np

dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict()


def get_numpy(key):
    if key not in dict_of_np:
        shared_array = multiprocessing.Array(ctypes.c_int32, 5)
        shared_np = np.frombuffer(shared_array.get_obj(), dtype=np.int32)
        dict_of_np[key] = shared_np
    return dict_of_np[key]


if __name__ == "__main__":
    a = get_numpy("5")
    a[1] = 5
    print(a)  # prints [0 5 0 0 0]
    b = get_numpy("5")
    print(b)  # prints [0 0 0 0 0]

I followed the instructions in Use numpy array in shared memory for multiprocessing to create the numpy arrays using a buffer, but when I try to save the resulting numpy array in a dict, it doesn't work. As you can see above, changes to a numpy array don't get saved when accessing the dict again using the key.
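The change disappears because a Manager dict does not hold the object itself. The proxy pickles each value and sends it to the manager's server process, and each lookup sends back a newly unpickled copy. Pickling a numpy array serializes its contents, not the shared buffer underneath, so the array read back from the dict is no longer connected to the `multiprocessing.Array`. The minimal sketch below reproduces this in a single process. It assumes a plain `pickle` round trip is a fair stand-in for the trip through the manager, which lets the example run without starting a Manager.

```python
import ctypes
import multiprocessing
import pickle

import numpy as np

# Shared ctypes array wrapped by numpy, as in the question's get_numpy()
shared_array = multiprocessing.Array(ctypes.c_int32, 5)
shared_np = np.frombuffer(shared_array.get_obj(), dtype=np.int32)

# Assumption: this pickle round trip stands in for storing shared_np in a
# Manager dict and reading it back (the proxy pickles values both ways)
from_dict = pickle.loads(pickle.dumps(shared_np))

shared_np[1] = 5
print(shared_np)                               # [0 5 0 0 0]
print(from_dict)                               # [0 0 0 0 0]
print(np.shares_memory(shared_np, from_dict))  # False: the data was copied
```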

How can I share a dict of numpy arrays? I need both the dict and the arrays to be shared and use the same memory.

ike
  • I found https://stackoverflow.com/questions/63753866/python-multiprocessing-how-to-modify-a-dictionary-created-in-the-main-process-f which has the same issue, but the solution described there is to copy the numpy array when modified, which doesn't work for me. The point of using numpy is to avoid any copying other than the elements modified. – ike Jan 26 '21 at 23:41

1 Answer


Based on our discussion from this question, I may have come up with a solution: by using a thread in the main process to handle the instantiation of multiprocessing.shared_memory.SharedMemory objects, you can ensure that a reference to each shared memory object sticks around, so the underlying memory isn't deleted too early. This only solves the Windows-specific problem, where the block is deleted as soon as no more references to it exist. It does not solve the separate problem that each opened SharedMemory instance must be held onto for as long as any array built on its memoryview (shm.buf) is needed.

This manager thread "listens" for messages on an input multiprocessing.Queue, creates or returns data about shared memory objects, and replies on an output queue. A lock ensures that only one child at a time sends a request and reads the reply, so responses cannot get mixed up between processes.

All shared memory objects are first created by the main process, and held onto until explicitly deleted so that other processes may access them.
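The lock-guarded round trip described above can be wrapped in a small helper, along the lines of the "helper functions here could make access less verbose" comment in the example's `create()`. This is a minimal sketch: `shm_request` is a hypothetical name, and it assumes the same `"method:arg"` requests and `"ok:..."`/`"ko:..."` replies that the example's controller uses. Holding the lock across both the `put` and the `get` is what stops another process from slipping in its own request and taking this reply.

```python
import queue
import threading


def shm_request(mq, rq, lock, method, arg):
    """Hypothetical helper: send one "method:arg" request and wait for its reply.

    The lock is held across both put() and get(), so no other process can send
    a request, or take this reply, in between.
    """
    with lock:
        mq.put(f"{method}:{arg}")
        response = rq.get()
    status, _, payload = response.partition(":")
    return status == "ok", payload


# In-process stand-ins: queue.Queue and threading.Lock offer the same put/get
# and context-manager interface as multiprocessing.Queue and Lock
mq, rq, lock = queue.Queue(), queue.Queue(), threading.Lock()
rq.put("ok:key123")  # pretend the controller has already replied
print(shm_request(mq, rq, lock, "create", "key123,8"))  # (True, 'key123')
print(mq.get())  # create:key123,8
```

With the real queues and lock, the first half of `create()` would reduce to `shm_request(mq, rq, lock, "create", "key123,8")` followed by `shm_request(mq, rq, lock, "size", name)`.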

Example:

import multiprocessing
from multiprocessing import shared_memory, Queue, Process, Lock
from threading import Thread
import numpy as np

class Exit_Flag: pass
 
class SHMController:
    def __init__(self):
        self._shm_objects = {}
        self.mq = Queue() #message input queue
        self.rq = Queue() #response output queue
        self.lock = Lock() #only let one child talk to you at a time
        self._processing_thread = Thread(target=self.process_messages)
    
    def start(self): #to be called after all child processes are started
        self._processing_thread.start()
        
    def stop(self):
        self.mq.put(Exit_Flag())
        self._processing_thread.join() #wait for the thread to finish and clean up
        
    def __enter__(self):
        self.start()
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
    
    def process_messages(self):
        while True:
            message_obj = self.mq.get()
            if isinstance(message_obj, Exit_Flag):
                break
            elif isinstance(message_obj, str):
                message = message_obj
                response = self.handle_message(message)
                self.rq.put(response)
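        for shm in self._shm_objects.values(): #free blocks that were never explicitly destroyed
            shm.close()
            shm.unlink()
        self._shm_objects.clear()
        self.mq.close()
        self.rq.close()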
    
    def handle_message(self, message):
        method, arg = message.split(':', 1)
        if method == "exists":
            if arg in self._shm_objects: #if shm.name exists or not
                return "ok:true"
            else:
                return "ok:false"
        if method == "size":
            if arg in self._shm_objects:
                return f"ok:{len(self._shm_objects[arg].buf)}"
            else:
                return "ko:-1"
        if method == "create":
            args = arg.split(",") #name, size or just size
            if len(args) == 1:
                name = None
                size = int(args[0])
            elif len(args) == 2:
                name = args[0]
                size = int(args[1])
            else:
                return f"ko:bad arguments '{arg}'"
            if name in self._shm_objects:
                return f"ko:'{name}' already created"
            else:
                try:
                    shm = shared_memory.SharedMemory(name=name, create=True, size=size)
                except FileExistsError:
                    return f"ko:'{name}' already exists"
                self._shm_objects[shm.name] = shm
                return f"ok:{shm.name}"
        if method == "destroy":
            if arg in self._shm_objects:
                self._shm_objects[arg].close()
                self._shm_objects[arg].unlink()
                del self._shm_objects[arg]
                return f"ok:'{arg}' destroyed"
            else:
                return f"ko:'{arg}' does not exist"
        return f"ko:unknown method '{method}'" #never reply None to an unrecognized request
    
def create(mq, rq, lock):
    #helper functions here could make access less verbose
    with lock:
        mq.put("create:key123,8")
        response = rq.get()
    print(response)
    if response[:2] == "ok":
        name = response.split(':')[1]
        with lock:
            mq.put(f"size:{name}")
            response = rq.get()
        print(response)
        if response[:2] == "ok":
            size = int(response.split(":")[1])
            shm = shared_memory.SharedMemory(name=name, create=False, size=size)
        else:
            print("Oh no....")
            return
    else:
        print("Uh oh....")
        return
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[:] = (1,2)
    print(arr)
    shm.close()
    
def modify(mq, rq, lock):
    while True: #until the shm exists
        with lock:
            mq.put("exists:key123")
            response = rq.get()
        if response == "ok:true":
            print("key:exists")
            break
    with lock:
        mq.put("size:key123")
        response = rq.get()
    print(response)
    if response[:2] == "ok":
        size = int(response.split(":")[1])
        shm = shared_memory.SharedMemory(name="key123", create=False, size=size)
    else:
        print("Oh no....")
        return
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[0] += 5
    print(arr)
    shm.close()
    
def delete(mq, rq, lock):
    pass #TODO make a test for this?

 
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn") #because I'm mixing threads and processes
    with SHMController() as controller:
        mq, rq, lock = controller.mq, controller.rq, controller.lock
        create_task = Process(target=create, args=(mq, rq, lock))
        create_task.start()
        create_task.join()
        modify_task = Process(target=modify, args=(mq, rq, lock))
        modify_task.start()
        modify_task.join()
    print("finished")
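Stripped of the queues and the controller thread, the mechanism the example relies on is small: a block created with `create=True` can be opened again by name, and numpy arrays built on either handle's `buf` see the same bytes. A minimal single-process sketch, assuming an OS-generated name (no `name=` argument) is acceptable for a demo, since it avoids clashing with an existing block:

```python
from multiprocessing import shared_memory

import numpy as np

# Create a block large enough for two int32 values (the OS may round the size up)
owner = shared_memory.SharedMemory(create=True, size=8)

# Open the same block a second time by name, as the child processes do
other = shared_memory.SharedMemory(name=owner.name, create=False)

a = np.ndarray((2,), dtype=np.int32, buffer=owner.buf)
b = np.ndarray((2,), dtype=np.int32, buffer=other.buf)

a[:] = (1, 2)  # like create()
b[0] += 5      # like modify()
print(a)       # [6 2]: both arrays view the same bytes
result = a.tolist()

# Drop the arrays before closing the handles whose memory they point into
del a, b
other.close()
owner.close()
owner.unlink()  # only the creating side removes the block
```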

To solve the second problem, that an array built on a shm's buffer is only safe to use while that specific shm object is alive, you must keep a reference to that exact shm object for as long as the array exists. Keeping a reference alongside the array is fairly straightforward: attach it as an attribute of a custom ndarray subclass (copied from the numpy guide to subclassing):

import numpy as np
from multiprocessing import shared_memory

class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array

    def __new__(cls, input_array, shm=None):
        obj = np.asarray(input_array).view(cls)
        obj.shm = shm
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.shm = getattr(obj, 'shm', None)
#example: `name` is the name of an existing block, `shape` the shape of the array stored in it
shm = shared_memory.SharedMemory(name=name)
np_array = SHMArray(np.ndarray(shape, buffer=shm.buf, dtype=np.int32), shm)
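To see why the subclass helps, here is a minimal sketch that uses a block created in the same process as a stand-in for one created by the controller. `attach` is a hypothetical helper: it opens the block by name and returns an `SHMArray` that carries its own `shm` handle, so the handle is not garbage-collected when the function returns, and `__array_finalize__` hands the same handle on to slices and other views.

```python
from multiprocessing import shared_memory

import numpy as np


class SHMArray(np.ndarray):
    """ndarray subclass that carries the SharedMemory handle backing its buffer."""

    def __new__(cls, input_array, shm=None):
        obj = np.asarray(input_array).view(cls)
        obj.shm = shm
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # Slices and other views inherit the handle, so they keep it alive too
        self.shm = getattr(obj, "shm", None)


def attach(name, shape, dtype=np.int32):
    """Hypothetical helper: open an existing block by name."""
    shm = shared_memory.SharedMemory(name=name, create=False)
    # The returned array references `shm`, so the handle survives this return
    return SHMArray(np.ndarray(shape, dtype=dtype, buffer=shm.buf), shm)


# Stand-in for a block created elsewhere (e.g. by the controller)
owner = shared_memory.SharedMemory(create=True, size=8)

arr = attach(owner.name, (2,))
arr[:] = (3, 4)
tail = arr[1:]
same_handle = tail.shm is arr.shm
values = arr.tolist()
print(same_handle, values)  # True [3, 4]

# Clean up: drop the arrays, then close this handle and remove the block
handle = arr.shm
del arr, tail
handle.close()
owner.close()
owner.unlink()
```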
Aaron
  • My current issue is that accessing a numpy array backed by a buffer created with a shared memory name is segfaulting. I think I'm going to have to file a bug report. – ike Feb 02 '21 at 01:16
  • The weirdest thing is it works fine in a single function but breaks when I return the numpy array from a different function. This doesn't have to do with references anymore. – ike Feb 02 '21 at 01:17
  • This answer only addresses how to keep a reference around to ensure Windows doesn't deallocate the block of shared memory (which it does automatically when the reference count reaches 0, rather than waiting for `unlink` to be called). My answer to the previous question explains why keeping a reference to each `shm` is necessary to avoid a segfault when accessing a numpy array built on its buffer. If you return an array, you must also return the shm, or it will be garbage collected when the function returns. This is why I suggested subclassing `ndarray` to keep an internal reference to the shm. – Aaron Feb 02 '21 at 03:04
  • This segfault doesn't appear to be related to references; my current test code keeps a reference locally scoped at all times. – ike Feb 02 '21 at 03:16
  • The code in your GitHub comment does not. `shm` inside your function gets GC'd when the function returns. – Aaron Feb 02 '21 at 03:35
  • Ah, I had thought you just needed a reference to the original shm, not a reference to the particular shm used to create the array. Now everything makes sense. Thanks! – ike Feb 02 '21 at 03:42
  • Kinda two separate problems, yes. Keeping the original one prevents Windows from deleting it (which will happen as soon as no references exist), but keeping the one used to create the array is what prevents the segfault. – Aaron Feb 02 '21 at 03:45