
I am trying to implement an asynchronous method that takes an array of ChannelReader<T>s and takes a value from any of the channels that has an item available. It is similar in functionality to the BlockingCollection<T>.TakeFromAny method, which has this signature:

public static int TakeFromAny(BlockingCollection<T>[] collections, out T item,
    CancellationToken cancellationToken);

This method returns the index in the collections array from which the item was removed. An async method cannot have out parameters, so the API that I am trying to implement is this:

public static Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default);
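For comparison, here is a minimal sketch of how the synchronous BlockingCollection<T>.TakeFromAny behaves when only one of the collections has an item. The collection contents are made up for illustration:

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;

// Three empty collections; only the third one receives an item.
var collections = new[]
{
    new BlockingCollection<string>(),
    new BlockingCollection<string>(),
    new BlockingCollection<string>(),
};
collections[2].Add("hello");

// Blocks until any collection has an item, removes exactly one item,
// and returns the index of the collection it came from.
int index = BlockingCollection<string>.TakeFromAny(
    collections, out string item, CancellationToken.None);

Console.WriteLine($"{item} from collection {index}"); // hello from collection 2
```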

The TakeFromAnyAsync<T> method should asynchronously read an item, and return the consumed item along with the index of the associated channel in the channelReaders array. If all the channels are completed (either successfully or with an error), or all of them become completed during the await, the method should asynchronously throw a ChannelClosedException.
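This contract mirrors what a single ChannelReader<T> already does. The small sketch below shows how WaitToReadAsync and ReadAsync differ on completed channels, which matters when TakeFromAnyAsync is built on top of WaitToReadAsync. The channel contents and the exception message are arbitrary:

```csharp
using System;
using System.Threading.Channels;
using System.Threading.Tasks;

// A channel that was completed without an error.
var completed = Channel.CreateUnbounded<int>();
completed.Writer.Complete();
bool canRead = await completed.Reader.WaitToReadAsync(); // false, no exception
Exception? readError = await CaptureAsync(() => completed.Reader.ReadAsync().AsTask());

// A channel that was completed with an error.
var faulted = Channel.CreateUnbounded<int>();
faulted.Writer.Complete(new ApplicationException("producer failed"));
Exception? waitFault = await CaptureAsync(() => faulted.Reader.WaitToReadAsync().AsTask());
Exception? readFault = await CaptureAsync(() => faulted.Reader.ReadAsync().AsTask());

Console.WriteLine(canRead);                            // False
Console.WriteLine(readError?.GetType().Name);          // ChannelClosedException
Console.WriteLine(waitFault?.GetType().Name);          // ApplicationException
Console.WriteLine(readFault?.GetType().Name);          // ChannelClosedException
Console.WriteLine(readFault?.InnerException?.Message); // producer failed

// Awaits the task and returns the exception it threw, or null if it succeeded.
static async Task<Exception?> CaptureAsync(Func<Task> action)
{
    try { await action(); return null; }
    catch (Exception ex) { return ex; }
}
```

Because WaitToReadAsync rethrows the producer's original exception, an implementation built on it has to translate a fault into a ChannelClosedException itself (or fall back to ReadAsync) to meet this contract.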

My question is: how can I implement the TakeFromAnyAsync<T> method? The implementation looks quite tricky. Under no circumstances should the method consume more than one item from the channels. It also should not leave behind fire-and-forget tasks or leave disposable resources undisposed. The method will typically be called in a loop, so it should also be reasonably efficient: its complexity should be no worse than O(n), where n is the number of channels.

For insight into where this method can be useful, take a look at the select statement of the Go language. From the tour:

The select statement lets a goroutine wait on multiple communication operations.

A select blocks until one of its cases can run, then it executes that case. It chooses one at random if multiple are ready.

select {
case msg1 := <-c1:
    fmt.Println("received", msg1)
case msg2 := <-c2:
    fmt.Println("received", msg2)
}

In the above example either a value will be taken from the channel c1 and assigned to the variable msg1, or a value will be taken from the channel c2 and assigned to the variable msg2. The Go select statement is not restricted to reading from channels. It can include multiple heterogeneous cases like writing to bounded channels, waiting for timers etc. Replicating the full functionality of the Go select statement is beyond the scope of this question.

Theodor Zoulias
  • what do you expect to return if all channels complete without returning the result? – alexm Jun 14 '22 at 23:31
  • @alexm the expected outcome in this case is that the `Task<(T Item, int Index)>` will complete in a [faulted](https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.task.isfaulted) state, containing a [`ChannelClosedException`](https://learn.microsoft.com/en-us/dotnet/api/system.threading.channels.channelclosedexception) in its `Exception.InnerException` property. – Theodor Zoulias Jun 15 '22 at 00:19

3 Answers


I came up with something like this:


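using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static async Task<(T Item, int Index)> TakeFromAnyAsync<T>(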
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default)
{
    if (channelReaders == null)
    {
        throw new ArgumentNullException(nameof(channelReaders));
    }

    if (channelReaders.Length == 0)
    {
        throw new ArgumentException("The list cannot be empty.", nameof(channelReaders));
    }

    if (channelReaders.Length == 1)
    {
        return (await channelReaders[0].ReadAsync(cancellationToken), 0);
    }

    // First attempt to read an item synchronously
    for (int i = 0; i < channelReaders.Length; ++i)
    {
        if (channelReaders[i].TryRead(out var item))
        {
            return (item, i);
        }
    }

    using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
    {

        var waitToReadTasks = channelReaders
                .Select(it => it.WaitToReadAsync(cts.Token).AsTask())
                .ToArray();

        var pendingTasks = new List<Task<bool>>(waitToReadTasks);

        while (pendingTasks.Count > 1)
        {
            var t = await Task.WhenAny(pendingTasks);

            if (t.IsCompletedSuccessfully && t.Result)
            {
                int index = Array.IndexOf(waitToReadTasks, t);
                var reader = channelReaders[index];

                // Attempt to read an item synchronously
                if (reader.TryRead(out var item))
                {
                    if (pendingTasks.Count > 1)
                    {
                        // Cancel pending "wait to read" on the remaining readers
                        // then wait for the completion 
                        try
                        {
                            cts.Cancel();
                            await Task.WhenAll((IEnumerable<Task>)pendingTasks);
                        }
                        catch { }
                    }
                    return (item, index);
                }

                // Due to the race condition item is no longer available
                if (!reader.Completion.IsCompleted)
                {
                    // .. but the channel appears to be still open, so we retry
                    var waitToReadTask = reader.WaitToReadAsync(cts.Token).AsTask();
                    waitToReadTasks[index] = waitToReadTask;
                    pendingTasks.Add(waitToReadTask);
                }

            }

            // Remove all completed tasks that could not yield 
            pendingTasks.RemoveAll(tt => tt == t || 
                tt.IsCompletedSuccessfully && !tt.Result || 
                tt.IsFaulted || tt.IsCanceled);

        }

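        int lastIndex = 0;
        if (pendingTasks.Count > 0)
        {
            lastIndex = Array.IndexOf(waitToReadTasks, pendingTasks[0]);
            // If the last channel faulted, WaitToReadAsync rethrows the channel's own
            // error; swallow it so that ReadAsync below throws a ChannelClosedException
            try { await pendingTasks[0]; } catch { }
        }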

        var lastItem = await channelReaders[lastIndex].ReadAsync(cancellationToken);
        return (lastItem, lastIndex);
    }
}

alexm
  • You might need to .Dispose each readTask before returning. – David Browne - Microsoft Jun 15 '22 at 00:00
  • @DavidBrowne-Microsoft : Why would one need to dispose incomplete tasks? see https://devblogs.microsoft.com/pfxteam/do-i-need-to-dispose-of-tasks/ – alexm Jun 15 '22 at 00:03
  • @DavidBrowne-Microsoft: `Task.WaitAny` would require disposing incomplete tasks. Does the same rule apply to the _async_ `Task.WhenAny()` ? – alexm Jun 15 '22 at 00:13
  • Of course do what Stephen Toub says :) – David Browne - Microsoft Jun 15 '22 at 00:24
  • @alexm thanks for the answer. Regarding disposing the tasks, it's a known exception to the general rule, as it's evident from the [link](https://devblogs.microsoft.com/pfxteam/do-i-need-to-dispose-of-tasks/) that you posted. Tasks don't need to be disposed. [Here](https://stackoverflow.com/questions/3734280/is-it-considered-acceptable-to-not-call-dispose-on-a-tpl-task-object) is another link. – Theodor Zoulias Jun 15 '22 at 00:28
  • @DavidBrowne-Microsoft : I understand your sarcasm, but could you answer my second question? – alexm Jun 15 '22 at 00:32
  • I believe Task.WhenAny and Task.WaitAny would follow the same rules. – David Browne - Microsoft Jun 15 '22 at 00:32
  • Regarding the implementation, honestly I don't think that it's satisfactory. First and foremost is assumes that after a successful `WaitToReadAsync`, the `ReadAsync` will be also successful immediately. This is not guaranteed. This consumer might be racing with other consumers that also want to read data from the same channels. Secondly, it launches multiple concurrent `WaitToReadAsync` operations, and then abandons some of them in a fire-and-forget fashion. Lastly, it is an [O(n²)](https://stackoverflow.com/a/72276921/11178549) implementation. – Theodor Zoulias Jun 15 '22 at 00:40
  • @TheodorZoulias : perhaps you're right about the race condition if there is more than one consumer listening to the channel. I disagree on the n^2 part - the code is not polling the reader, merely waiting for an event. – alexm Jun 15 '22 at 00:46
  • There is a `while` loop with potentially n iterations. Inside the loop is an `await Task.WhenAny` operation that internally attaches continuations to all the pending tasks. The total number of attached continuations might be less than n², because the pending tasks are reduced on each iteration, but it's surely not O(n). I am pretty sure that I could demonstrate it experimentally by testing your implementation with 10,000-20,000-30,000 channels, and measuring the overhead. How confident are you that the overhead will increase linearly? – Theodor Zoulias Jun 15 '22 at 00:57
  • @TheodorZoulias : I am only confident that one day I will be gone.. yet, I still don't see n^2 for the average case scenario. Similar example: the worst case performance for quicksort is proven to be n^2, yet everyone believes that it is n*log(n) on average – alexm Jun 15 '22 at 01:27
  • If you fix the other two problems, I might close my eyes to the debatable O(n²) issue and accept your answer. :-) – Theodor Zoulias Jun 15 '22 at 02:10
  • @TheodorZoulias : agreed, I need time to reflect on this .. thanks for an interesting challenge – alexm Jun 15 '22 at 02:14
  • Nice, thanks! Now it works almost flawlessly. There is still a [tiny problem](https://dotnetfiddle.net/Re7F1T) when the last channel completes with error, but it should be easy to fix. – Theodor Zoulias Jun 16 '22 at 02:52
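The first point in the review above (a successful WaitToReadAsync does not guarantee that the following read succeeds) can be reproduced deterministically with a single channel and two would-be consumers. The value 42 is arbitrary:

```csharp
using System;
using System.Threading.Channels;

var channel = Channel.CreateUnbounded<int>();
channel.Writer.TryWrite(42);

// Two consumers both observe that an item is available...
bool firstSeesItem = await channel.Reader.WaitToReadAsync();
bool secondSeesItem = await channel.Reader.WaitToReadAsync();

// ...but only one of them can actually take it.
bool firstGotItem = channel.Reader.TryRead(out int firstItem);
bool secondGotItem = channel.Reader.TryRead(out _);

Console.WriteLine($"{firstSeesItem} {secondSeesItem}"); // True True
Console.WriteLine($"{firstGotItem} {secondGotItem}");   // True False
```

This is why both the revised answer and the alternative implementation below go back to waiting when TryRead returns false after a successful wait.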

Here is another approach. This implementation is conceptually the same as alexm's implementation up to the point where no channel has an item immediately available. From there it avoids the Task.WhenAny-in-a-loop pattern, and instead starts an asynchronous loop for each channel. All loops race to update a shared ValueTuple<T, int, bool> consumed variable, which is updated in a critical region in order to prevent consuming an element from more than one channel.

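using System;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

/// <summary>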
/// Takes an item asynchronously from any one of the specified channel readers.
/// </summary>
public static async Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default)
{
    ArgumentNullException.ThrowIfNull(channelReaders);
    if (channelReaders.Length == 0) throw new ArgumentException(
        $"The {nameof(channelReaders)} argument is a zero-length array.");
    foreach (var cr in channelReaders) if (cr is null) throw new ArgumentException(
        $"The {nameof(channelReaders)} argument contains at least one null element.");

    cancellationToken.ThrowIfCancellationRequested();

    // Fast path (at least one channel has an item available immediately)
    for (int i = 0; i < channelReaders.Length; i++)
        if (channelReaders[i].TryRead(out var item))
            return (item, i);

    // Slow path (all channels are currently empty)
    using var linkedCts = CancellationTokenSource
        .CreateLinkedTokenSource(cancellationToken);

    (T Item, int Index, bool HasValue) consumed = default;

    Task[] tasks = channelReaders.Select(async (channelReader, index) =>
    {
        while (true)
        {
            try
            {
                if (!await channelReader.WaitToReadAsync(linkedCts.Token)
                    .ConfigureAwait(false)) break;
            }
            // Only the exceptions filtered below are expected.
            catch (OperationCanceledException)
                when (linkedCts.IsCancellationRequested) { break; }
            catch when (channelReader.Completion.IsCompleted
                && !channelReader.Completion.IsCompletedSuccessfully) { break; }

            // This channel has an item available now.
            lock (linkedCts)
            {
                if (consumed.HasValue)
                    return; // An item has already been consumed from another channel.

                if (!channelReader.TryRead(out var item))
                    continue; // We lost the race to consume the available item.

                consumed = (item, index, true); // We consumed an item successfully.
            }
            linkedCts.Cancel(); // Cancel the other tasks.
            return;
        }
    }).ToArray();

    // The tasks should never fail. If a task ever fails, we have a bug.
    try { foreach (var task in tasks) await task.ConfigureAwait(false); }
    catch (Exception ex) { Debug.Fail("Unexpected error", ex.ToString()); throw; }

    if (consumed.HasValue)
        return (consumed.Item, consumed.Index);
    cancellationToken.ThrowIfCancellationRequested();
    Debug.Assert(channelReaders.All(cr => cr.Completion.IsCompleted));
    throw new ChannelClosedException();
}

It should be noted that this solution, as well as alexm's solution, depends on canceling en masse all pending WaitToReadAsync operations when an element has been consumed. Unfortunately this triggers the infamous memory leak issue that affects .NET channels with idle producers: when any async operation on a channel is canceled, the canceled operation remains in memory, attached to the internal structures of the channel, until an element is written to the channel. Microsoft has triaged this behavior as by-design, although the possibility of improving it has not been ruled out. Interestingly, this ambiguity makes the behavior ineligible for documentation, so the only way to learn about it is by chance: either by reading about it in unofficial sources, or by running into it.

Theodor Zoulias

The problem is a lot easier if channels are used the way they're used in Go: Channel(Readers) as input, Channel(Readers) as output.

IEnumerable<ChannelReader<T>> sources=....;
await foreach(var msg in sources.TakeFromAny(token).ReadAllAsync(token))
{
....
}

or

var merged=sources.TakeFromAny(token);
...
var msg=await merged.ReadAsync(token);

In this case, the input from all channel readers is copied to a single output channel. The return value of the method is the ChannelReader of this channel.

CopyToAsync helper

A CopyToAsync function can be used to copy messages from an input source to the output channel:

static async Task CopyToAsync<T>(
        this ChannelReader<T> input,
        ChannelWriter<T> output,
        CancellationToken token=default)
{
   while (await input.WaitToReadAsync(token).ConfigureAwait(false))
   {
         //Early exit if cancellation is requested
         while (!token.IsCancellationRequested &&  input.TryRead(out T? msg))
         {
             await output.WriteAsync(msg,token).ConfigureAwait(false);
         }
   }
}

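This code is similar to ReadAllAsync, but exits immediately if cancellation is requested. ReadAllAsync will return all available items even if cancellation is requested. The read-side methods used here don't throw when the input channel is completed normally: WaitToReadAsync simply returns false, which ends the loop, and TryRead returns false as well. This makes error handling a lot easier. WriteAsync, on the other hand, throws a ChannelClosedException if the output channel has already been completed.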
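The ReadAllAsync behavior mentioned above can be observed with a short sketch: cancellation is requested right after the first item, yet the items that are already buffered are still yielded, and only the next WaitToReadAsync call inside ReadAllAsync observes the cancellation. The channel contents are arbitrary:

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;

var channel = Channel.CreateUnbounded<int>();
for (int i = 1; i <= 3; i++) channel.Writer.TryWrite(i);

using var cts = new CancellationTokenSource();
var received = new List<int>();
bool canceled = false;
try
{
    await foreach (int item in channel.Reader.ReadAllAsync(cts.Token))
    {
        received.Add(item);
        cts.Cancel(); // request cancellation right after the first item
    }
}
catch (OperationCanceledException)
{
    canceled = true; // thrown once the buffered items are exhausted
}

Console.WriteLine($"{string.Join(",", received)} canceled={canceled}"); // 1,2,3 canceled=True
```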

Error Handling and Railway-oriented programming

WaitToReadAsync does throw if the source faults, and that exception will be propagated to the calling method and, through Task.WhenAll, to the output channel.

This can be a bit messy because it interrupts the entire pipeline. To avoid this, the error could be swallowed or logged inside CopyToAsync. An even better option would be to use Railway-oriented programming and wrap all messages in a Result<TMsg,TError> class, e.g.:

// Result<TValue,TError> is not a .NET type: this assumes a user-defined class
// with static FromValue/FromError factory methods.
static async Task CopyToAsync<T>(
        this ChannelReader<T> input,
        ChannelWriter<Result<T,Exception>> output,
        CancellationToken token=default)
{
   try
   {
     while (await input.WaitToReadAsync(token).ConfigureAwait(false))
     {
         //Early exit if cancellation is requested
         while (!token.IsCancellationRequested &&  input.TryRead(out T? msg))
         {
             var newMsg=Result<T,Exception>.FromValue(msg);
             await output.WriteAsync(newMsg,token).ConfigureAwait(false);
         }
     }
  }
  catch(Exception exc)
  {
    // Forward the error as a message instead of faulting the pipeline
    output.TryWrite(Result<T,Exception>.FromError(exc));
  }
}

TakeFromAny

TakeFromAny (MergeAsync may be a better name) can be:

static ChannelReader<T> TakeFromAny<T>(
        this IEnumerable<ChannelReader<T>> inputs,
        CancellationToken token=default)
{
    var outChannel=Channel.CreateBounded<T>(1);

    var readers=inputs.Select(rd=>CopyToAsync(rd,outChannel.Writer,token));

    _ = Task.WhenAll(readers)
            .ContinueWith(t=>outChannel.Writer.TryComplete(t.Exception));
    return outChannel.Reader;
}

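Using a bounded capacity of 1 ensures that the backpressure behavior of downstream code doesn't change: a copy loop can't write its next item until the consumer has taken the previous one. Note, however, that a copy loop blocked in WriteAsync is already holding an item it has read from its source.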
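A minimal sketch of that backpressure with a standalone bounded channel of capacity 1 (the values are arbitrary). The pending WriteAsync plays the role of a copy loop that is already holding an item:

```csharp
using System;
using System.Threading.Channels;

var channel = Channel.CreateBounded<int>(1);

bool firstAccepted = channel.Writer.TryWrite(1);  // the buffer has room
bool secondAccepted = channel.Writer.TryWrite(2); // the buffer is full

// A writer that already holds item 2 has to wait (default full mode is Wait).
var pendingWrite = channel.Writer.WriteAsync(2);
bool blockedBeforeRead = !pendingWrite.IsCompleted;

// The consumer takes item 1, which frees the slot for the waiting writer.
channel.Reader.TryRead(out int taken);
await pendingWrite;
channel.Reader.TryRead(out int next);

Console.WriteLine($"{firstAccepted} {secondAccepted} {blockedBeforeRead} {taken} {next}"); // True False True 1 2
```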

Adding a source index

This can be adjusted to emit the index of the source as well:

static async Task CopyToAsync<T>(
        this ChannelReader<T> input,int index,
        ChannelWriter<(T,int)> output,
        CancellationToken token=default)
{
  while (await input.WaitToReadAsync(token).ConfigureAwait(false))
  {
        while (!token.IsCancellationRequested &&  input.TryRead(out T? msg))
        {

            await output.WriteAsync((msg,index),token).ConfigureAwait(false);
        }
  }
}

static ChannelReader<(T,int)> TakeFromAny<T>(
        this IEnumerable<ChannelReader<T>> inputs,
        CancellationToken token=default)
{
    var outChannel=Channel.CreateBounded<(T,int)>(1);

    var readers=inputs.Select((rd,idx)=>CopyToAsync(rd,idx,outChannel.Writer,token));

    _ = Task.WhenAll(readers)
            .ContinueWith(t=>outChannel.Writer.TryComplete(t.Exception));
    return outChannel.Reader;
}
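The fragments above have to be placed in a static class before they compile. As a hedged illustration, here is a self-contained sketch of the indexed variant written as a local function in a top-level program (extension methods can't be local functions, so the helper is a plain static method). The name Merge, the sample channels, and their contents are made up for the example:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

var a = Channel.CreateUnbounded<string>();
var b = Channel.CreateUnbounded<string>();
a.Writer.TryWrite("a1");
b.Writer.TryWrite("b1");
a.Writer.Complete();
b.Writer.Complete();

var received = new List<(string Item, int Index)>();
await foreach (var entry in Merge(new[] { a.Reader, b.Reader }).ReadAllAsync())
    received.Add(entry);

Console.WriteLine(received.Count); // 2

// Hypothetical helper: copies every input into one bounded output channel,
// tagging each item with the index of the input it came from.
static ChannelReader<(T Item, int Index)> Merge<T>(
    IReadOnlyList<ChannelReader<T>> inputs, CancellationToken token = default)
{
    var output = Channel.CreateBounded<(T Item, int Index)>(1);

    Task[] copies = inputs.Select(async (input, index) =>
    {
        while (await input.WaitToReadAsync(token).ConfigureAwait(false))
            while (!token.IsCancellationRequested && input.TryRead(out var item))
                await output.Writer.WriteAsync((item, index), token).ConfigureAwait(false);
    }).ToArray();

    // Complete the output once all copy loops finish, propagating any fault.
    _ = Task.WhenAll(copies).ContinueWith(
        t => output.Writer.TryComplete(t.Exception),
        CancellationToken.None,
        TaskContinuationOptions.ExecuteSynchronously,
        TaskScheduler.Default);

    return output.Reader;
}
```

As the first comment below points out, this merged-channel design does not match the requested TakeFromAnyAsync API: each copy loop may read an item from its source before the consumer asks for one.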
Panagiotis Kanavos
  • I am not seeing the API that I have asked for in the question, which is a `TakeFromAnyAsync` that returns a `Task<(T Item, int Index)>`. So I can't verify the correctness of this solution experimentally, because it doesn't match the problem that I am trying to solve. From the looks of it, the requirement of consuming at most one element from any of the source channels is probably not met. – Theodor Zoulias Jun 28 '22 at 09:44
  • The consuming part is `ChannelReader.ReadAsync` which only returns a single item. Channels aren't BlockingCollection. Trying to make one work exactly like the other is far more difficult than using each the way it's meant – Panagiotis Kanavos Jun 28 '22 at 09:49
  • In that case `Task.WhenAll` will propagate that exception to the output channel, and cause the other `Copy` tasks to abort as well. A linked CancellationTokenSource can be used to cancel the other tasks early and gracefully in case of exceptions. If you don't want to propagate the exception you can swallow it inside `CopyToAsync` – Panagiotis Kanavos Jun 28 '22 at 09:58
  • Error handling becomes a *lot* easier if railway-oriented programming techniques are used, propagating `Result` objects along the pipeline. – Panagiotis Kanavos Jun 28 '22 at 10:15