
I have a problem that, in simplified form:

  1. has a loop which samples new points
  2. evaluates them with a complex/slow function
  3. accepts them if the value is above an ever-increasing threshold.

Here is example code for illustration:

from numpy.random import uniform
from time import sleep

def userfunction(x):
    # do something complicated
    # but computation always takes roughly the same time
    sleep(1) # comment this out if too slow
    xnew = uniform() # in reality, a non-trivial function of x
    y = -0.5 * xnew**2
    return xnew, y

x0, cur = userfunction([])
x = [x0] # a sequence of points

while cur < -2e-16:
    # this should be parallelised

    # search for a new point higher than a threshold
    x1, next = userfunction(x)
    if next <= cur:
        # throw away (this branch is taken 99% of the time)
        pass
    else:
        cur = next
        print(cur)
        x.append(x1) # note that userfunction depends on x

print(x)

I want to parallelise this (e.g. across a cluster). The problem is that I need to terminate the other workers when a successful point has been found, or at least inform them of the new x (if they manage to get above the new threshold with an older x, the result is still acceptable). As long as no point has been successful, I need the workers to keep repeating.

I am looking for tools/frameworks which can handle this type of problem, in any scientific programming language (C, C++, Python, Julia, etc., no Fortran please).

Can this be solved with MPI semi-elegantly? I don't understand how I can inform/interrupt/update workers with MPI.

Update: added code comments to say that most tries are unsuccessful and do not change x, the variable userfunction depends on.

j13r
  • In the user function you will have to check once in a while whether a better solution has been found by the other threads. – Serge Rogatch Sep 01 '17 at 12:39
  • @SergeRogatch, wouldn't that require N^2 communication? Alternatively, I could make the workers ask the main program for the current x. In my problem, a new point is typically found in only 1 out of 1000 tries, so there would be a lot of useless calls if it is the workers who ask. – j13r Sep 01 '17 at 12:45
  • No, definitely not `N*N` communications. A worker informs the main thread about the best value found. The main thread communicates this event and the value to all the other workers. The other workers check once in a while for this event, and depending on whether they have a better value, they either communicate it to the main thread, or exit. – Serge Rogatch Sep 01 '17 at 12:48
  • Closely related https://stackoverflow.com/questions/43973504/mpi-asynchronous-broadcast-from-unknown-source – Zulan Sep 01 '17 at 13:40
  • Could you maybe start a second thread in each MPI process that runs in parallel with your main code? It would then sit in a loop waiting (blocking) on an MPI message tagged as 'NEWSURVIVOR' and, when it gets it, it would change an atomic variable shared with the main thread. The main thread would check that variable each time through its loop. When a new survivor is found, you would just broadcast with a tag 'NEWSURVIVOR'. Just a thought... – Mark Setchell Sep 01 '17 at 17:55
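Serge Rogatch's scheme can be tried out on one machine before touching MPI. In that scheme, workers report only successes to a coordinator, the coordinator republishes the accepted state, and workers check it once in a while. The sketch below is a stand-in for MPI, not MPI itself, and rests on these assumptions:

- Python threads play the ranks.
- A `queue.Queue` plays the point-to-point messages to the coordinator.
- A lock-protected `Board` plays the broadcast, or the shared variable that Mark Setchell's listener thread would update.
- `Board`, `worker`, `coordinator`, `run`, the `max_evals` cap and the looser `THRESHOLD` are hypothetical choices made for illustration.

Because of the GIL, this gives no speed-up for CPU-bound Python code. It only exercises the protocol.

```python
# Hypothetical threads-and-queue stand-in for the MPI protocol; all names are illustrative.
import queue
import random
import threading

THRESHOLD = -1e-4  # assumption: much looser than the question's -2e-16 so the demo ends quickly


def userfunction(x, rng):
    # Toy stand-in for the slow function in the question; x is ignored here.
    xnew = rng.uniform(0.0, 1.0)
    return xnew, -0.5 * xnew ** 2


class Board:
    """Latest accepted state, written by the coordinator and read by every worker."""

    def __init__(self, cur, x):
        self._lock = threading.Lock()
        self.cur, self.x, self.done = cur, list(x), False

    def publish(self, cur, x, done):
        with self._lock:
            self.cur, self.x, self.done = cur, list(x), done

    def snapshot(self):
        with self._lock:
            return self.cur, list(self.x), self.done


def worker(seed, board, reports, max_evals=100_000):
    rng = random.Random(seed)
    for _ in range(max_evals):
        cur, x, done = board.snapshot()  # local check, no message to the coordinator
        if done:
            break
        x1, y = userfunction(x, rng)
        if y > cur:  # only successes produce a message
            reports.put((y, x1))
    reports.put(None)  # tell the coordinator this worker has stopped


def coordinator(board, reports, n_workers):
    cur, x, _ = board.snapshot()
    stopped = 0
    while stopped < n_workers:  # keep draining until every worker has stopped
        msg = reports.get()  # blocks, like a receive from any source
        if msg is None:
            stopped += 1
            continue
        y, x1 = msg
        if y > cur:  # reports that no longer beat the threshold are dropped
            cur, x = y, x + [x1]
            board.publish(cur, x, done=cur >= THRESHOLD)  # the "broadcast"
    return cur, x


def run(n_workers=4):
    x0, cur = userfunction([], random.Random(0))
    board, reports = Board(cur, [x0]), queue.Queue()
    threads = [threading.Thread(target=worker, args=(seed, board, reports))
               for seed in range(1, n_workers + 1)]
    for t in threads:
        t.start()
    cur, x = coordinator(board, reports, n_workers)
    for t in threads:
        t.join()
    return cur, x


if __name__ == "__main__":
    cur, x = run()
    print(cur, len(x))
```

Only successful evaluations generate traffic: one report per improvement plus one publish, not N^2 messages. A point found with an older x is still accepted if it beats the current threshold, as the question allows.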

2 Answers


If userfunction() does not take too long, here is an option that qualifies as "MPI semi-elegantly".

To keep things simple, let's assume rank 0 is only an orchestrator and does not compute anything.

On rank 0:

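cur = -inf  # start below any reachable value so that the loop is entered
x = []
while cur < -2e-16:
    # block until some worker reports a candidate (its value and point)
    MPI_Recv(buf=cur+x1, src=MPI_ANY_SOURCE)
    x.append(x1)
    # announce the new threshold and sequence to every worker
    MPI_Ibcast(buf=cur+x, root=0, request=req)
    MPI_Wait(request=req)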

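On every rank != 0:

x0, cur = userfunction([])
x = [x0] # a sequence of points
# post the broadcast once; it completes when rank 0 announces an update
MPI_Ibcast(buf=newcur+newx, root=0, request=req)

while cur < -2e-16:
    # search for a new point higher than the threshold
    x1, next = userfunction(x)
    if next <= cur:
        # throw away (this branch is taken 99% of the time),
        # but check whether rank 0 has announced an update meanwhile
        MPI_Test(request=req, flag=found)
    else:
        # report the candidate, then wait for rank 0's next broadcast
        MPI_Send(buffer=next+x1, dest=0)
        MPI_Wait(request=req)
        found = True
    if found:
        # adopt the threshold and sequence broadcast by rank 0
        cur, x = newcur, newx
        if cur < -2e-16:
            # re-post the broadcast for the next update
            MPI_Ibcast(buf=newcur+newx, root=0, request=req)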

Extra logic is needed to correctly handle a few situations: rank 0 doing computation too; several ranks finding a solution at the same time; and the subsequent messages, which must all be consumed by rank 0.
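Rank 0 can handle the case where several ranks find a point at the same time in two steps. First, it drains every pending message; the other answer's code does this with `comm.iprobe`. Then it folds the whole batch into the state with a rule like the one below.

`merge_reports` is a hypothetical helper written in plain Python with no MPI calls, so it can be tested on its own. It treats each point as an opaque value.

```python
# hypothetical helper for rank 0; plain Python, no MPI calls
def merge_reports(cur, x, reports):
    """Fold a batch of (value, point) reports into the accepted (cur, x) state.

    Reports are processed in arrival order; one that does not beat the current
    threshold (e.g. a second rank that found a point at the same time) is dropped.
    Returns (cur, x, changed) so rank 0 broadcasts at most once per batch.
    """
    changed = False
    for y, x1 in reports:
        if y > cur:
            cur, x = y, x + [x1]  # build a new list; the caller's list is not mutated
            changed = True
    return cur, x, changed


# usage: three ranks reported before rank 0 looked at its messages
cur, x, changed = merge_reports(-0.5, ["x0"], [(-0.3, "a"), (-0.4, "b"), (-0.1, "c")])
print(cur, x, changed)  # -0.1 ['x0', 'a', 'c'] True
```

Keeping the intermediate improvement "a" mirrors what the serial loop in the question would have done. A variant could keep only the best report of each batch.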

Strictly speaking, a task is not "interrupted" when a solution is found on another task. Instead, each task periodically checks whether a solution was found by another task. So there is a delay between the time a solution is found somewhere and the time all tasks stop looking for solutions (roughly one userfunction() call, since each task checks once per iteration). If userfunction() does not take "too long", this looks very acceptable to me.

Gilles Gouaillardet
  • This looks like a good start. What is next()? I think there is an outer loop missing in this solution -- How can I update the workers to inform them of the new array x and new threshold cur? – j13r Sep 01 '17 at 13:59
  • That data dependency is the difficulty. You may assume that the first if branch is taken ~99% of the time, so the dependency is quite loose. However, if the other one is taken, all workers should be informed. – j13r Sep 01 '17 at 14:22
  • userfunction is indeed non-deterministic (it has a uniform() random number call at the start). – j13r Sep 01 '17 at 14:27
  • Oh I see, I edited my answer and removed my comments accordingly. – Gilles Gouaillardet Sep 01 '17 at 14:34
  • Hmm, I am having trouble understanding the communication. Could you elaborate on the rank=0 code part (it's OK for me if rank=0 does not do any computation)? Specifically, I don't understand what MPI_Test does in the loop and when found is set. Also, I guess there should be a loop around both code segments? – j13r Sep 01 '17 at 15:23
  • @j13r I updated the pseudocode. Is it more understandable now? – Gilles Gouaillardet Sep 04 '17 at 01:39

I solved it roughly with the following code.

This transmits only curmax at the moment, but the other array can be sent with a second broadcast. MPI collectives take no tag, so broadcasts are matched by the order in which all ranks call them.
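Every rank must post a receive buffer of the right size before a (non-blocking) broadcast arrives. A growing sequence therefore needs one of two approaches:

- a first broadcast carrying the length, followed by a second one carrying the data, as suggested above;
- a buffer of fixed capacity.

The sketch below shows the fixed-capacity option. It assumes that points are scalars, as in the toy example, and `MAX_POINTS`, `pack` and `unpack` are hypothetical names.

The resulting `array.array` exposes the buffer protocol. So something like `comm.Ibcast(pack(curmax, seq))` on rank 0, with `comm.Ibcast(buf)` on a preallocated buffer of the same size on the workers, should work with mpi4py. That MPI part is not exercised here.

```python
# hypothetical helpers; assumes each point is a scalar float
from array import array

MAX_POINTS = 1000  # hypothetical capacity; every rank must use the same value


def pack(cur, x, capacity=MAX_POINTS):
    """Lay out [cur, len(x), x_0, ..., x_{n-1}, 0.0, ...] in one float64 buffer."""
    if len(x) > capacity:
        raise ValueError("point sequence does not fit in the preallocated buffer")
    buf = array("d", [0.0]) * (capacity + 2)
    buf[0], buf[1] = cur, float(len(x))
    buf[2:2 + len(x)] = array("d", x)
    return buf


def unpack(buf):
    """Inverse of pack: return (cur, x) from a received buffer."""
    n = int(buf[1])
    return buf[0], list(buf[2:2 + n])


buf = pack(-0.125, [0.5, 0.25], capacity=4)
print(len(buf), unpack(buf))  # 6 (-0.125, [0.5, 0.25])
```

A fixed buffer wastes some bandwidth but needs only one collective per update. The two-broadcast variant has to wait for the length to arrive before it can post the second Ibcast.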

import numpy
import time

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

import logging
logging.basicConfig(filename='mpitest%d.log' % rank,level=logging.DEBUG)
logFormatter = logging.Formatter("[%(name)s %(levelname)s]: %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
consoleHandler.setLevel(logging.INFO)
logging.getLogger().addHandler(consoleHandler)

log = logging.getLogger(__name__)

if rank == 0:
    curmax = numpy.random.random()
    seq = [curmax]
    log.info('%d broadcasting starting value %f...' % (rank, curmax))
    comm.Ibcast(numpy.array([curmax])).Wait()  # complete the broadcast instead of dropping its request

    was_updated = False
    while True:
        # check if news available
        status = MPI.Status()
        a_avail = comm.iprobe(source=MPI.ANY_SOURCE, tag=12, status=status)
        if a_avail:
            sugg = comm.recv(source=status.Get_source(), tag=12)
            log.info('%d received new limit from %d: %s' % (rank, status.Get_source(), sugg))
            if sugg < curmax:
                curmax = sugg
                seq.append(curmax)
                log.info('%d updating to %s' % (rank, curmax))
                was_updated = True
            else:
                # ignore
                pass
        # check if next message is already waiting:
        if comm.iprobe(source=MPI.ANY_SOURCE, tag=12):
            # consume it first before broadcasting outdated info
            continue

        if was_updated:
            log.info('%d broadcasting new limit %f...' % (rank, curmax))
            comm.Ibcast(numpy.array([curmax])).Wait()  # complete the broadcast instead of dropping its request
            was_updated = False
        else:
            # no message waiting for us and no broadcast done, so pause
            time.sleep(0.1)

    print(seq, rank)  # not reached as written: the loop above never exits
else:
    log.info('%d waiting for root to send us starting value...' % (rank))
    nextmax = numpy.empty(1, dtype=float)
    comm.Ibcast(nextmax).Wait()

    amax = float(nextmax[0])  # take the scalar out of the 1-element array
    numpy.random.seed(rank)
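    update_req = comm.Ibcast(nextmax)
    send_reqs = []  # outstanding isend requests, kept so each one is completed rather than dropped
    while True:
        a = numpy.random.uniform()
        if a < amax:
            log.info('%d found new: %s, sending to root' % (rank, a))
            amax = a
            send_reqs.append(comm.isend(a, dest=0, tag=12))
        # forget the sends that have completed
        send_reqs = [r for r in send_reqs if not r.test()[0]]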
        s = update_req.Get_status()
        #log.info('%d bcast status: %s' % (rank, s))
        if s:
            update_req.Wait()
            log.info('%d receiving new limit from root, %s' % (rank, nextmax))
            amax = float(nextmax[0])  # take the scalar out of the 1-element array
            update_req = comm.Ibcast(nextmax)
j13r