
I'm writing a machine learning program with the following components:

  1. A shared "Experience Pool" with a binary-tree-like data structure.

  2. N simulator processes. Each adds an "experience object" to the pool every once in a while. The pool is responsible for balancing its tree.

  3. M learner processes that sample a batch of "experience objects" from the pool every few moments and perform whatever learning procedure.
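The answer's skeleton imports an `ExperienceTree` class and only needs two methods from it: `update` and `sample`. The sketch below is a hypothetical stand-in for that class, not the asker's actual tree. It is a fixed-capacity sum tree, one common binary-tree-like structure for experience pools. Leaves hold priorities and every internal node holds the sum of its two children, so insertion and priority-proportional sampling each take O(log capacity) steps. The class name `SumTree`, the `capacity` and `priority` parameters, and the ring-buffer overwrite policy are all assumptions.

```python
import random


class SumTree:
    """Hypothetical stand-in for ExperienceTree: a fixed-capacity binary sum tree.

    Node 1 is the root and node i has children 2*i and 2*i + 1.
    Leaves live at indices capacity .. 2*capacity - 1.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.nodes = [0.0] * (2 * capacity)  # index 0 is unused
        self.data = [None] * capacity
        self.next_slot = 0  # oldest slot, overwritten once the tree is full

    def update(self, experience, priority=1.0):
        """Store an experience and propagate its priority change up to the root."""
        slot = self.next_slot
        self.data[slot] = experience
        node = slot + self.capacity
        delta = priority - self.nodes[node]
        while node >= 1:
            self.nodes[node] += delta
            node //= 2
        self.next_slot = (slot + 1) % self.capacity

    def _find_slot(self, value):
        """Walk from the root to the leaf whose cumulative-priority range holds value."""
        node = 1
        while node < self.capacity:
            left = 2 * node
            if value < self.nodes[left]:
                node = left
            else:
                value -= self.nodes[left]
                node = left + 1
        return node - self.capacity

    def sample(self, batch_size=1):
        """Draw batch_size experiences (with replacement), proportional to priority."""
        total = self.nodes[1]
        if total <= 0:
            raise ValueError("cannot sample from an empty tree")
        return [self.data[self._find_slot(random.random() * total)]
                for _ in range(batch_size)]
```

For example, after `tree = SumTree(capacity=4)`, `tree.update("a")` and `tree.update("b", priority=3.0)`, a call to `tree.sample(8)` returns a batch that is dominated by `"b"`.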

I don't know the best way to implement this. I'm not using TensorFlow, so I can't take advantage of its parallel capabilities. More concretely:

  • My first thought was Python 3's built-in multiprocessing library. Unlike threads, however, processes created with the multiprocessing module cannot directly update the same global object. My hunch is that I should use the server-proxy model. Could anyone give me a rough skeleton to start with?
  • Is mpi4py a better solution?
  • Are there other libraries that would be a better fit? I've looked at Celery, Disque, etc., but it's not obvious to me how to adapt them to my use case.
Ainz Titor
  • Do you really need access to global memory? Why not have a separate, dedicated, non-Python in-memory queue (Redis comes to mind, or a graph database if you can JSONify your features) that all your processes can write to? This will be easier to scale horizontally, and you can optimise that component for reads/writes/etc. down the line. Plus, Python's `multiprocessing` queues don't support random sampling natively, so this frees you to use other solutions. – Akshat Mahajan Jul 27 '17 at 00:08
  • It's not a simple queue. The global shared pool is actually a tree structure that needs dedicated code to add/sample stuff. I've implemented that logic in Python already, but it's single-threaded now. – Ainz Titor Jul 27 '17 at 00:12
  • Then just write it as a separate module and wrap a server process around it. That solves your `multiprocessing` problem, and still gives you horizontal scaling for free. Also, define 'better fit'. What variable are you trying to optimise? Throughput? Latency? Memory footprint? Low garbage collection? We can't give you solutions until you know what you're trying to improve. Your proposed model can be implemented using any tool; your constraints will define your final solution better than your requirements. – Akshat Mahajan Jul 27 '17 at 00:15
  • For now, I don't really care about latency, etc. I'm having trouble even implementing a working version of "wrap a server process around it". I've read the [docs](https://docs.python.org/3/library/multiprocessing.html#sharing-state-between-processes). It tells me how to share `Array` or `Dict`, but doesn't show anything close to my use case. Should I use [pipe](https://docs.python.org/3/library/multiprocessing.html#pipes-and-queues)? Or [BaseManager](https://docs.python.org/3/library/multiprocessing.html#customized-managers)? A brief skeleton code would be very helpful. Thanks! – Ainz Titor Jul 27 '17 at 00:28
  • @Ainz Titor: Using `BaseManager` will get you closer to what you need than `pipe`, as it provides a built-in server out of the box. I'll post some example code in a bit. – Akshat Mahajan Jul 27 '17 at 00:45
  • @AkshatMahajan: this would be awesome. Much appreciated! – Ainz Titor Jul 27 '17 at 02:07

1 Answer

Based on the comments, what you're really looking for is a way to update a shared object from a set of processes that are carrying out a CPU-bound task. Because the work is CPU-bound, multiprocessing is the obvious choice: CPython's global interpreter lock (GIL) stops threads from running Python code in parallel. If most of your work were I/O-bound, multithreading would have been simpler.

Your problem follows a simple server-client model: the clients use the server as a stateful store, and no child process needs to communicate or synchronise directly with another.

The simplest way to do this is to:

  1. Start a separate process that contains a server.
  2. Inside the server logic, provide methods to update and read from a single object.
  3. Treat both your simulator and learner processes as separate clients that can periodically read and update the global state.

From the server's perspective, the identity of the clients doesn't matter - only their actions do.

This can be done with a customised manager from `multiprocessing.managers`, as follows:

```python
# server.py

from multiprocessing.managers import BaseManager
# this represents the data structure you've already implemented.
from ... import ExperienceTree

# An important note: every call to Server.Shared_Experience_Pool() makes the
# server process create a *new* ExperiencePool instance, and the server keeps
# that instance alive only while some proxy still refers to it. Once all
# proxies to an instance are gone (e.g. all of your workers die), the instance
# is discarded. I have chosen to sidestep both issues by keeping the tree in a
# class variable: every instance created in the server process shares the same
# tree, and the tree lives as long as the server process does. You may define
# __init__, etc. if you so wish, but be aware that per-instance state is
# neither shared between clients nor kept once all workers are gone.
class ExperiencePool(object):

    tree = ExperienceTree()

    @classmethod
    def update(cls, experience_object):
        ''' Implement methods to update the tree with an experience object. '''
        cls.tree.update(experience_object)

    @classmethod
    def sample(cls):
        ''' Implement methods to sample the tree's experience objects. '''
        return cls.tree.sample()

# subclass base manager
class Server(BaseManager):
    pass

# register the class you just created - now each call to
# Server.Shared_Experience_Pool() returns a proxy to an ExperiencePool instance.
Server.register('Shared_Experience_Pool', ExperiencePool)

if __name__ == '__main__':
    # run the server on port 8080 of your own machine.
    # Don't use `with Server(...)` here: entering the context manager calls
    # start(), which launches a separate server process, and get_server() then
    # fails because the manager is no longer in its initial state.
    server_process = Server(('localhost', 8080), authkey=b'none')
    server_process.get_server().serve_forever()
```
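Why does the class-variable trick work? Each `Shared_Experience_Pool()` call asks the server process to build a fresh `ExperiencePool` instance. Because `tree` lives on the class, every instance in that process reads and writes the same tree. The sketch below checks this in a single process, with no manager involved. A hypothetical list-backed `ExperienceTree` stands in for the real one.

```python
import random


class ExperienceTree:
    """Hypothetical stand-in for the real tree: a list with uniform sampling."""

    def __init__(self):
        self.items = []

    def update(self, experience_object):
        self.items.append(experience_object)

    def sample(self):
        return random.choice(self.items)


class ExperiencePool(object):
    # Created once, when the class body runs, and shared by every instance.
    tree = ExperienceTree()

    @classmethod
    def update(cls, experience_object):
        cls.tree.update(experience_object)

    @classmethod
    def sample(cls):
        return cls.tree.sample()


# The manager's server would build one instance per Shared_Experience_Pool()
# call; here two instances are built directly to mimic a simulator and a learner.
simulator_view = ExperiencePool()
learner_view = ExperiencePool()
simulator_view.update("transition-1")
print(learner_view.sample())  # prints transition-1: the learner sees the write
```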

Now for all of your clients you can just do:

```python
# client.py - you can always have a separate client file for a learner and a simulator.

from multiprocessing.managers import BaseManager

class Server(BaseManager):
    pass

# Clients only need the name: the server process creates the objects, so no
# callable (and no import of server.py) is required here.
Server.register('Shared_Experience_Pool')

if __name__ == '__main__':
    # connect to the server already running on port 8080 of your own machine.
    server_process = Server(('localhost', 8080), authkey=b'none')
    server_process.connect()
    experience_pool = server_process.Shared_Experience_Pool()
    # now do your own thing and call `experience_pool.sample()` or `update` whenever you want.
```

You may then launch one server.py and as many workers as you want.
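You can also try the whole pattern without opening two terminals. The sketch below runs the manager's server in a daemon thread of the same process, using `get_server()` plus `serve_forever()` as in server.py, and connects two clients to it. It makes three assumptions: a hypothetical list-backed `ExperienceTree`, a placeholder authkey, and port 0 so the OS picks a free port. In real use, keep the server and the clients in separate processes as described above. The client-side manager registers only the name, because the callable may be omitted for a manager that only calls `connect()`.

```python
import random
import threading
from multiprocessing.managers import BaseManager


class ExperienceTree:
    """Hypothetical stand-in for the real tree."""

    def __init__(self):
        self.items = []

    def update(self, experience_object):
        self.items.append(experience_object)

    def sample(self):
        return random.choice(self.items)


class ExperiencePool(object):
    tree = ExperienceTree()

    @classmethod
    def update(cls, experience_object):
        cls.tree.update(experience_object)

    @classmethod
    def sample(cls):
        return cls.tree.sample()


class Server(BaseManager):
    pass


class ClientManager(BaseManager):
    pass


AUTHKEY = b"change-me"  # hypothetical key; server and clients must agree on it

Server.register("Shared_Experience_Pool", ExperiencePool)  # server builds objects
ClientManager.register("Shared_Experience_Pool")           # clients only need the name


def start_server_in_thread():
    # Port 0 lets the OS choose a free port; server.address reports the real one.
    server = Server(address=("127.0.0.1", 0), authkey=AUTHKEY).get_server()
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server.address


def connect_client(address):
    manager = ClientManager(address=address, authkey=AUTHKEY)
    manager.connect()
    return manager.Shared_Experience_Pool()


address = start_server_in_thread()
simulator = connect_client(address)
learner = connect_client(address)
simulator.update({"state": 0, "reward": 1.0})
print(learner.sample())  # {'state': 0, 'reward': 1.0}
```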

Is This The Best Design?

Not always. You may run into race conditions: the manager's server handles each client connection in its own thread, so a learner that samples while a simulator is writing may receive stale or partially updated data.

If you want reads to always see the latest completed write, you can additionally take a lock whenever a simulator writes, so that no other process can read until the write finishes.
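Here is a minimal sketch of that idea, assuming the tree itself is not thread-safe. The manager's server handles each connection on its own thread, so a `threading.Lock` stored as a class variable (the same trick as the tree) can serialize `update` and `sample` inside the server process. The lock never leaves that process, so it never has to be pickled. The sketch locks reads as well as writes, which is simple but coarse; a readers-writer lock would allow concurrent samples. Plain threads stand in for the server's per-connection threads. The deliberately non-atomic `ExperienceTree` and the `_lock` name are hypothetical.

```python
import random
import threading
import time


class ExperienceTree:
    """Hypothetical stand-in whose update is deliberately not atomic."""

    def __init__(self):
        self.items = []
        self.count = 0

    def update(self, experience_object):
        self.items.append(experience_object)
        current = self.count
        time.sleep(0)  # yield to other threads, as a slow rebalance might
        self.count = current + 1  # updates would be lost here if calls interleaved

    def sample(self):
        return random.choice(self.items)


class ExperiencePool(object):
    tree = ExperienceTree()
    # Created and used only inside the server process, so it is never pickled.
    _lock = threading.Lock()

    @classmethod
    def update(cls, experience_object):
        with cls._lock:
            cls.tree.update(experience_object)

    @classmethod
    def sample(cls):
        with cls._lock:
            return cls.tree.sample()


def simulator(pool, n_updates):
    # Stands in for one client connection's handler thread in the server.
    for i in range(n_updates):
        pool.update(i)


threads = [threading.Thread(target=simulator, args=(ExperiencePool(), 1000))
           for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(ExperiencePool.tree.count)  # 8000: no update was lost
```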

Akshat Mahajan
  • Thanks a lot for the skeleton code. It's very smart to use class variables to sidestep the garbage collection issue. Does that mean there's really no need for the [weakref methods here](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.managers.BaseProxy._callmethod)? It looks like your code serves exactly that purpose without all the reference counting problems. – Ainz Titor Jul 27 '17 at 02:24
  • A proxy object is just an object. The docs use special objects subclassed from BaseProxy because those have better error handling than plain old objects, but it's not a crime to use regular objects without that machinery. All that happens under the hood is that a call is evaluated inside the server's object and the result is ferried back as a response to the client's object; you don't need a special layer to do that. – Akshat Mahajan Jul 27 '17 at 02:30
  • I see. I'm getting `assert self._state.value == State.INITIAL AssertionError` when I run a bare-bones server script. The error occurs at the `get_server()` call. Port 8080 is available on my computer. – Ainz Titor Jul 27 '17 at 02:43
  • Hmm, the error goes away when I get rid of the context manager and simply call `server_proc = Server(...)`. Any idea why? – Ainz Titor Jul 27 '17 at 02:45
  • @Ainz Titor: Funny that that should happen. I had the same issue in my client.py, which is why I chose to omit the context manager. Might be worth poking around in the source to figure out what's happening. – Akshat Mahajan Jul 27 '17 at 02:50
  • Quick question: if I want to add locking to ensure that only one process updates the tree at a time, how can I make every client I launch see the same global lock object? I know how to do locking with `multiprocessing.Process`, but not in the current context. – Ainz Titor Jul 27 '17 at 02:55
  • @Ainz: Use the same trick and make the lock object a class variable. Then have someone lock it. – Akshat Mahajan Jul 27 '17 at 02:58
  • Going back to our discussion at the beginning: now that there's a minimal working version, which Python library (if not `multiprocessing`) would you recommend to optimize latency? All I care about is minimizing communication time to speed up learning. – Ainz Titor Jul 27 '17 at 04:05
  • Once you're on the same file system, there's little you can do to reduce latency beyond making communication more compact (serializing or compressing the data you send). It's worth your time to profile your code and see where the bottlenecks are. For example, if you're CPU-bound, you can use Cython to reduce interpreter overhead, or try the GPU if you're doing a lot of parallel work that benefits from SIMD. Libraries won't help until you know your bottlenecks. – Akshat Mahajan Jul 27 '17 at 04:43
  • Gotcha. FYI, I tried adding a lock as a class variable. It doesn't work because the lock object is not serializable. I also tried [SyncManager.Lock()](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.managers.SyncManager.Lock). It doesn't throw an error, but it doesn't do anything either: `with server_proc.Lock(): ...` locks nothing. – Ainz Titor Jul 27 '17 at 05:08
  • Probably an issue with implementation - would recommend asking as a separate question :) Try https://stackoverflow.com/questions/2545961/how-to-synchronize-a-python-dict-with-multiprocessing – Akshat Mahajan Jul 27 '17 at 05:17
  • Good idea. I've posted a new question here: https://stackoverflow.com/questions/45342200/how-to-use-syncmanager-lock-or-event-correctly – Ainz Titor Jul 27 '17 at 05:39
  • FYI, I figured out a [workaround](https://stackoverflow.com/questions/45342200/how-to-use-syncmanager-lock-or-event-correctly/45351044#45351044). @Akshat Mahajan – Ainz Titor Jul 27 '17 at 12:38
  • Quick related question: suppose I use `Server.register('experience_pool', lambda : ExperiencePoolObject)` instead of the class variable trick, will that send the entire `ExperiencePoolObject` over the network, instead of the update delta? @Akshat Mahajan – Ainz Titor Jul 27 '17 at 14:20
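On the last question, here is a hedged sketch. The names `pool_instance` and `get_pool` and the plain-attribute `ExperiencePool` are hypothetical, and a named function replaces the lambda. Registering a callable that returns one pre-built instance makes every `Shared_Experience_Pool()` call hand out a proxy to that same object. With a manager, the object itself stays in the server process. Each proxy method call pickles only its arguments and its return value, so `update` ships the new experience, not the whole pool. A module-level reference also keeps the instance alive after every proxy is released. The server runs in a daemon thread here only so that the example is self-contained.

```python
import threading
from multiprocessing.managers import BaseManager


class ExperiencePool:
    """Hypothetical pool with ordinary instance state (no class-variable trick)."""

    def __init__(self):
        self.items = []

    def update(self, experience_object):
        self.items.append(experience_object)

    def size(self):
        return len(self.items)


# One instance, referenced by a module-level name so it outlives every proxy.
pool_instance = ExperiencePool()


def get_pool():
    # Returns the same object on every call, so all proxies share it.
    return pool_instance


class Server(BaseManager):
    pass


class ClientManager(BaseManager):
    pass


Server.register("Shared_Experience_Pool", callable=get_pool)
ClientManager.register("Shared_Experience_Pool")

AUTHKEY = b"change-me"  # hypothetical key
server = Server(address=("127.0.0.1", 0), authkey=AUTHKEY).get_server()
threading.Thread(target=server.serve_forever, daemon=True).start()

client = ClientManager(address=server.address, authkey=AUTHKEY)
client.connect()
proxy_a = client.Shared_Experience_Pool()
proxy_b = client.Shared_Experience_Pool()
# Only the argument [0.1, 0.2] is pickled and sent; the pool itself stays put.
proxy_a.update([0.1, 0.2])
print(proxy_b.size())  # 1: both proxies refer to the one server-side object
```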