
Let's say I have a function IsAPrimaryColour() which works by calling three other functions IsRed(), IsGreen() and IsBlue(). Since the three functions are quite independent of one another, they can run concurrently. The return conditions are:

  1. If any of the three functions returns true, IsAPrimaryColour() should also return true. There is no need to wait for the other functions to finish. That is: IsAPrimaryColour() is true if IsRed() is true OR IsGreen() is true OR IsBlue() is true
  2. If all functions return false, IsAPrimaryColour() should also return false. That is: IsAPrimaryColour() is false if IsRed() is false AND IsGreen() is false AND IsBlue() is false
  3. If any of the three functions returns an error, IsAPrimaryColour() should also return the error. There is no need to wait for the other functions to finish, or to collect any other errors.

The thing I'm struggling with is how to return from the function as soon as any of the three functions returns true, while still waiting for all three to finish if they all return false. If I use a sync.WaitGroup, I will need to wait for all 3 goroutines to finish before I can return from the calling function.

Therefore, I'm using a loop counter to keep track of how many messages I have received on the channel, and returning from the function once I have received all 3 messages.

https://play.golang.org/p/kNfqWVq4Wix

package main

import (
    "errors"
    "fmt"
    "time"
)

func main() {
    x := "something"
    result, err := IsAPrimaryColour(x)

    if err != nil {
        fmt.Printf("Error: %v\n", err)
    } else {
        fmt.Printf("Result: %v\n", result)
    }
}

func IsAPrimaryColour(value interface{}) (bool, error) {
    found := make(chan bool, 3)
    errors := make(chan error, 3)
    defer close(found)
    defer close(errors)
    var nsec int64 = time.Now().UnixNano()

    //call the first function, return the result on the 'found' channel and any errors on the 'errors' channel
    go func() {
        result, err := IsRed(value)
        if err != nil {
            errors <- err
        } else {
            found <- result
        }
        fmt.Printf("IsRed done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
    }()

    //call the second function, return the result on the 'found' channel and any errors on the 'errors' channel
    go func() {
        result, err := IsGreen(value)
        if err != nil {
            errors <- err
        } else {
            found <- result
        }
        fmt.Printf("IsGreen done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
    }()

    //call the third function, return the result on the 'found' channel and any errors on the 'errors' channel
    go func() {
        result, err := IsBlue(value)
        if err != nil {
            errors <- err
        } else {
            found <- result
        }
        fmt.Printf("IsBlue done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
    }()

    //loop counter which will be incremented every time we read a value from the 'found' channel
    var counter int

    for {
        select {
        case result := <-found:
            counter++
            fmt.Printf("received a value on the results channel after %f nanoseconds. Value of counter is %d\n", float64(time.Now().UnixNano()-nsec), counter)
            if result {
                fmt.Printf("some goroutine returned true\n")
                return true, nil
            }
        case err := <-errors:
            if err != nil {
                fmt.Printf("some goroutine returned an error\n")
                return false, err
            }
        default:
        }

        //check if we have received all 3 messages on the 'found' channel. If so, all 3 functions must have returned false and we can thus return false also
        if counter == 3 {
            fmt.Printf("all goroutines have finished and none of them returned true\n")
            return false, nil
        }
    }
}

func IsRed(value interface{}) (bool, error) {
    return false, nil
}

func IsGreen(value interface{}) (bool, error) {
    time.Sleep(time.Millisecond * 100) //change this to a value greater than 200 to make this function take longer than IsBlue()
    return true, nil
}

func IsBlue(value interface{}) (bool, error) {
    time.Sleep(time.Millisecond * 200)
    return false, errors.New("something went wrong")
}

Although this works well enough, I wonder if I'm not overlooking some language feature to do this in a better way?

  • The empty `default` will make this a fast loop that will consume an entire core. – Adrian Aug 26 '21 at 13:28
  • Return a struct from the goroutines containing the result and an error. Then you can simply read three times from the channel. – Burak Serdar Aug 26 '21 at 13:29
  • Your requirement is unclear to me. "how to exit the function if any other three functions return true, but also to wait for all three to finish if they all return false" sounds like a contradiction. It sounds like you want to exit _and_ you want to wait (not exit) at the same time? Can you clarify your goal? – Jonathan Hall Aug 26 '21 at 13:35
  • @Flimzy What I meant was that if any of the 3 functions returns true, the calling function should return true (without waiting for the other 2 functions to return anything). However, the calling function would need to wait for all three functions to return false before it can return false too. Hope that clarifies. –  Aug 26 '21 at 13:39
  • So you want to return `true` as soon as any one of the three returns `true`, then cancel the others? – Jonathan Hall Aug 26 '21 at 13:57
  • @Flimzy yes indeed –  Aug 26 '21 at 14:06
  • The proper solution for that is a `context.Context`. I'm sure I've seen a related answer... let me search for it. – Jonathan Hall Aug 26 '21 at 14:11
  • This is related: https://stackoverflow.com/a/45502591/13860 – Jonathan Hall Aug 26 '21 at 14:12

4 Answers


errgroup.WithContext can help simplify the concurrency here.

You want to stop all of the goroutines if an error occurs, or if a result is found. If you can express “a result is found” as a distinguished error (along the lines of io.EOF), then you can use errgroup's built-in “cancel on first error” behavior to shut down the whole group:

func IsAPrimaryColour(ctx context.Context, value interface{}) (bool, error) {
    var nsec int64 = time.Now().UnixNano()

    errFound := errors.New("result found")
    g, ctx := errgroup.WithContext(ctx)

    g.Go(func() error {
        result, err := IsRed(ctx, value)
        if result {
            err = errFound
        }
        fmt.Printf("IsRed done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
        return err
    })

    …

    err := g.Wait()
    if err == errFound {
        fmt.Printf("some goroutine returned errFound\n")
        return true, nil
    }
    if err != nil {
        fmt.Printf("some goroutine returned an error\n")
        return false, err
    }
    fmt.Printf("all goroutines have finished and none of them returned true\n")
    return false, nil
}

(https://play.golang.org/p/MVeeBpDv4Mn)

bcmills
  • Thanks for your answer. Context does indeed seem to be the 'correct' Go way to write such functions. It doesn't work in my specific case as the functions I'm working with are from an external library and I can't edit the vendor's code to handle the cancellation. But very interesting to learn this anyways! –  Aug 26 '21 at 16:14
  • If you can't edit the vendor's code to handle cancellation, then it is not safe to return early: your program will have unbounded memory consumption if any of the operations stalls. But you can still use an `errgroup.Group` to simplify the aggregation step: https://play.golang.org/p/Xx_tx_bHhaO – bcmills Aug 26 '21 at 16:23

Some remarks:

  • You don't need to close the channels: you know beforehand how many values you expect to read, which is a sufficient exit condition.
  • You don't need to duplicate the function calls by hand; put the functions in a slice.
  • Since you use a slice, you don't even need a counter or a hardcoded 3; just use the length of the function slice.
  • The default case in the select is useless; just block on the input you are waiting for.

So once you get rid of all the fat, the code looks like this:


func IsAPrimaryColour(value interface{}) (bool, error) {
    fns := []func(interface{}) (bool, error){IsRed, IsGreen, IsBlue}
    found := make(chan bool, len(fns))
    errors := make(chan error, len(fns))

    for i := 0; i < len(fns); i++ {
        fn := fns[i]
        go func() {
            result, err := fn(value)
            if err != nil {
                errors <- err
                return
            }
            found <- result
        }()
    }

    for i := 0; i < len(fns); i++ {
        select {
        case result := <-found:
            if result {
                return true, nil
            }
        case err := <-errors:
            if err != nil {
                return false, err
            }
        }
    }
    return false, nil
}
  • You don't need to observe the time at each and every async call; just measure how long the overall call took to return.
func main() {
    now := time.Now()
    x := "something"
    result, err := IsAPrimaryColour(x)

    if err != nil {
        fmt.Printf("Error: %v\n", err)
    } else {
        fmt.Printf("Result: %v\n", result)
    }
    fmt.Println("it took", time.Since(now))
}

https://play.golang.org/p/bARHS6c6m1c

  • In that example, if any result is true the subsequent goroutines may leak (undetected) for an arbitrarily long time. (Recall that Context cancellation is asynchronous.) For non-trivial programs, that can result in out-of-memory failures under load, data races in tests, skewed benchmarks, and the like. – bcmills Aug 26 '21 at 15:50
  • @bcmills, can you elaborate with an example, please? I don't see how it might leak: https://play.golang.org/p/zWB9s6zj76l –  Aug 26 '21 at 15:54
  • If you were talking about the case where the called functions are abnormally slow, then yes, that might happen. But I feel obliged to note that passing a context is only one part of the solution: the function in question might well lack the context-cancellation checks it needs to return early. Going down that path, anything can happen, and the language does not give us the tools to ensure correct behavior. I am arguing because I would like to understand your point... –  Aug 26 '21 at 16:01
  • “Abnormally slow” operations can and do happen in real programs, especially under abnormal load (such as during a partial network outage, or if a service goes viral for some reason). If that happens, you want the program to degrade gracefully, not run out of memory and crash. The simplest way to degrade gracefully is to structure the calls to be synchronous instead of abandoning asynchronous operations in flight. – bcmills Aug 26 '21 at 16:10
  • Yes..? But it's not “relearning” if we teach people the more robust patterns in the first place. (To that end, I did a whole talk on [rethinking concurrency patterns in Go](https://drive.google.com/file/d/1nPdvhB0PutEJzdCq5ms6UI58dp50fcAN/view).) – bcmills Aug 26 '21 at 16:12
  • Thanks for that, I will check it out. Sorry if I am being bad here, but, reading you, I am eager to see the upcoming `sort.SliceWithCtx`. TBH I feel a bit sad about this... –  Aug 26 '21 at 16:13

The idiomatic way to handle multiple concurrent function calls, and cancel any outstanding after a condition, is with the use of a context value. Something like this:

func operation1(ctx context.Context) bool { ... }
func operation2(ctx context.Context) bool { ... }
func operation3(ctx context.Context) bool { ... }

func atLeastOneSuccess() bool {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel() // Ensure any functions still running get the signal to stop
    results := make(chan bool, 3) // A channel to send results
    go func() {
        results <- operation1(ctx)
    }()
    go func() {
        results <- operation2(ctx)
    }()
    go func() {
        results <- operation3(ctx)
    }()
    for i := 0; i < 3; i++ {
        result := <-results
        if result {
            // One of the operations returned success, so we'll return that
            // and let the deferred call to cancel() tell any outstanding
            // functions to abort.
            return true
        }
    }
    // We've looped through all return values, and they were all false
    return false
}

Of course this assumes that each of the operationN functions actually honors a canceled context. This answer discusses how to do that.

Jonathan Hall
  • In that example, if any result is `true` the subsequent goroutines may leak (undetected) for an arbitrarily long time. (Recall that `Context` cancellation is asynchronous.) For non-trivial programs, that can result in out-of-memory failures under load, data races in tests, skewed benchmarks, and the like. – bcmills Aug 26 '21 at 15:49
  • Thanks for your answer. Context does indeed seem to be the 'correct' Go way to write such functions. It doesn't work in my specific case as the functions I'm working with are from an external library and I can't edit the vendor's code to handle the cancellation. But very interesting to learn this anyways! –  Aug 26 '21 at 16:14
  • If you can't abort the operation, then you only have two choices: wait until all the functions have returned, or ignore the fact that some are still running and let them continue in the background, potentially consuming more resources than you expect. – Jonathan Hall Aug 26 '21 at 18:00

You don't have to block the main goroutine on `wg.Wait()`; you could block a separate goroutine instead, for example:

doneCh := make(chan struct{})

go func() {
    wg.Wait()
    close(doneCh)
}()

Then you can wait on doneCh in your select to see if all the routines have finished.

Adrian
  • Thanks for the answer. One further question, just for my clarification: I thought wg.Wait() would return only when all the goroutines have finished executing (or rather, when all of them have decremented the WaitGroup counter). To use my earlier example, if IsRed() returns true, wouldn't wg.Wait() still wait for IsBlue() and IsGreen() to finish? –  Aug 26 '21 at 13:46
  • Yes. Which is why you'd use it in the `select` - that way you can concurrently check if all routines have finished, or if one of them has returned true, or if one of them has returned an error. – Adrian Aug 26 '21 at 14:10
  • This approach still leaks the goroutines that are in flight if one of the other branches of the `select` returns — it simplifies the loop, but still risks running out of memory if any of the other operations consistently stalls. – bcmills Aug 26 '21 at 16:17
  • Thanks @Adrian. I've selected this as the answer as this approach does not require me to modify the functions. This is especially useful to me as the functions I'm working with are from an external library and I can't edit them to add Context. –  Aug 26 '21 at 16:17
  • @bcmills, true, there is very much the possibility that the goroutines will continue to run in the background, thereby leaking memory. In my particular case, this is not a huge problem as the functions don't run for too long (they don't do any I/O, external calls or have iterations within them) so I can live with the overhead. Thanks! –  Aug 26 '21 at 16:20
  • @bcmills it doesn't exactly *leak* goroutines, since they will end eventually, but it could - depending on the overall implementation context - lead to many outstanding goroutines, causing memory exhaustion. It's definitely not the best solution, but it's the smallest change. – Adrian Aug 26 '21 at 17:08
  • There is no guarantee that they will end eventually. If you were to write a unit-test for the `IsAPrimaryColour` function, the test would pass even if those goroutines deadlock and never return in the early-exit case. – bcmills Aug 26 '21 at 17:35
  • @bcmills the quoted goroutines will definitely end eventually. All they do is sleep and return. The real-world routines may or may not, but they're not shown, and unless they contain defects, they will return eventually. – Adrian Aug 26 '21 at 18:34
  • The “unless they contain defects” part is a heavy lift: real programs often *do* contain defects. If you consistently wait for goroutines to finish, then you can detect those defects fairly easily during testing — they're right there in the goroutine dump when the test times out. If you don't, then you're left trying to sift through a much larger core dump when your program OOMs in production. – bcmills Aug 27 '21 at 15:39