
I have an async method that works like an enhanced Task.WhenAll. It takes a bunch of tasks and returns when all of them have completed.

public async Task MyWhenAll(Task[] tasks) {
    ...
    await Something();
    ...

    // all tasks are completed
    if (someTasksFailed)
        throw ??
}

My question is: how do I get the method to return a Task that looks like the one returned from Task.WhenAll when one or more of the tasks have failed?

If I collect the exceptions and throw an AggregateException, it gets wrapped in another AggregateException.

Edit: Full Example

async Task Main() {
    // Baseline: the built-in Task.WhenAll
    try {
        Task.WhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump(); // LINQPad's Dump() visualizer
    }

    // My version
    try {
        MyWhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }
}

public async Task MyWhenAll(Task t1, Task t2) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    try {
        await Task.WhenAll(t1, t2);
    }
    catch {
        // t1.Exception and t2.Exception are themselves AggregateExceptions
        throw new AggregateException(new[] { t1.Exception, t2.Exception });
    }
}
public async Task Throw(int id) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    throw new InvalidOperationException("Inner" + id);
}

For Task.WhenAll, the exception is an AggregateException with 2 inner exceptions.

For MyWhenAll, the exception is an AggregateException with one inner AggregateException, which in turn has 2 inner exceptions.
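The nesting can be reproduced outside LINQPad with a small sketch. The class `NestingDemo` and the methods `ThrowAggregateAsync` and `Observe` are hypothetical names chosen for illustration. The sketch relies on the fact that an async method stores whatever it throws as a single exception on its task, so an AggregateException thrown from inside it is wrapped again when `.Wait()` builds its own AggregateException.

```csharp
using System;
using System.Threading.Tasks;

public static class NestingDemo
{
    // Same shape as Throw(id) in the question.
    public static async Task Throw(int id)
    {
        await Task.Delay(TimeSpan.FromMilliseconds(100));
        throw new InvalidOperationException("Inner" + id);
    }

    // Hypothetical: an async method that throws an AggregateException itself.
    // The async machinery stores it as ONE exception on the returned task.
    public static async Task ThrowAggregateAsync()
    {
        await Task.Yield();
        throw new AggregateException(
            new InvalidOperationException("Inner1"),
            new InvalidOperationException("Inner2"));
    }

    // Blocks with .Wait() and returns the AggregateException it throws (or null).
    public static AggregateException Observe(Task task)
    {
        try
        {
            task.Wait();
            return null;
        }
        catch (AggregateException ex)
        {
            return ex;
        }
    }
}
```

`Observe(Task.WhenAll(Throw(1), Throw(2)))` yields two inner InvalidOperationExceptions, while `Observe(ThrowAggregateAsync())` yields a single inner AggregateException that holds the two.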

Edit: Why I am doing this

I often need to call paging APIs and want to limit the number of simultaneous connections.

The actual method signatures are:

public static async Task<TResult[]> AsParallelAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel)
public static async Task<TResult[]> AsParallelUntilAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel, Func<Task<TResult>, bool> predicate)

This means I can do paging like this:

var pagedRecords = await Enumerable.Range(1, int.MaxValue)
                                   .Select(x => GetRecordsAsync(pageSize: 1000, pageNumber: x))
                                   .AsParallelUntilAsync(maxParallel: 5, x => x.Result.Count < 1000);
var records = pagedRecords.SelectMany(x => x).ToList();

It all works fine; the aggregate within an aggregate is just a minor inconvenience.

Asked by adrianm; edited by Theodor Zoulias
  • I think we need to see a minimal, reproducible example. If I try your code, I only get a single `AggregateException`: https://dotnetfiddle.net/u6EVSE – canton7 Apr 15 '19 at 11:15
  • I assume your question is answered here - https://stackoverflow.com/questions/25912899/will-awaiting-multiple-tasks-observe-more-than-the-first-exception – olegk Apr 15 '19 at 13:21
  • Why not propagate the exception thrown by WhenAll, rather than catching it and throwing your own? Or use WaitAll? – canton7 Apr 15 '19 at 14:55
  • @canton7, thank you for your reply. Propagating the AggregateException from WhenAll makes no difference. It still becomes an aggregate inside an aggregate. – adrianm Apr 15 '19 at 15:00
  • Somewhat related: [I want await to throw AggregateException, not just the first Exception](https://stackoverflow.com/questions/18314961/i-want-await-to-throw-aggregateexception-not-just-the-first-exception) – Theodor Zoulias Sep 21 '21 at 23:28

3 Answers


async methods are designed to set at most a single exception on the returned task, never several.

This leaves you with two options. You can either avoid making the method async in the first place and build it from other means, such as ContinueWith and Unwrap:

public Task MyWhenAll(Task t1, Task t2)
{
    return Task.Delay(TimeSpan.FromMilliseconds(100))
        .ContinueWith(_ => Task.WhenAll(t1, t2))
        .Unwrap();
}
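One way to check that this non-async version really behaves like Task.WhenAll is the sketch below, which wraps the method in a hypothetical `ContinueWithDemo` class. Passing `TaskScheduler.Default` to ContinueWith is an addition in this sketch (it keeps the continuation off whatever TaskScheduler.Current happens to be); the answer's version omits it.

```csharp
using System;
using System.Threading.Tasks;

public static class ContinueWithDemo
{
    // Not async: the returned task is the Unwrap() proxy of Task.WhenAll,
    // so it carries every exception of t1 and t2.
    public static Task MyWhenAll(Task t1, Task t2)
    {
        return Task.Delay(TimeSpan.FromMilliseconds(100))
            .ContinueWith(_ => Task.WhenAll(t1, t2), TaskScheduler.Default)
            .Unwrap();
    }

    // Same shape as Throw(id) in the question.
    public static async Task Throw(int id)
    {
        await Task.Delay(TimeSpan.FromMilliseconds(100));
        throw new InvalidOperationException("Inner" + id);
    }
}
```

Calling `.Wait()` on `MyWhenAll(Throw(1), Throw(2))` should then throw one AggregateException with two inner InvalidOperationExceptions, the same shape Task.WhenAll produces.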

If you have a more complex method that would be harder to write without using await, then you'll need to unwrap the nested aggregate exceptions, which is tedious, although not overly complex, to do:

// Needs: using System; using System.Collections.Generic; using System.Threading.Tasks;
// As an extension method, this must live in a non-nested static class.
public static Task UnwrapAggregateException(this Task taskToUnwrap)
{
    var tcs = new TaskCompletionSource<bool>();

    taskToUnwrap.ContinueWith(task =>
    {
        if (task.IsCanceled)
            tcs.SetCanceled();
        else if (task.IsFaulted)
        {
            // For a faulted task, task.Exception is always an AggregateException;
            // store its leaf exceptions rather than the nested wrappers.
            tcs.SetException(Flatten(task.Exception));
        }
        else //successful
            tcs.SetResult(true);
    }, TaskScheduler.Default); // don't depend on TaskScheduler.Current

    // Walks the AggregateException tree and yields every non-aggregate exception.
    IEnumerable<Exception> Flatten(AggregateException exception)
    {
        var stack = new Stack<AggregateException>();
        stack.Push(exception);
        while (stack.Count > 0)
        {
            var next = stack.Pop();
            foreach (Exception inner in next.InnerExceptions)
            {
                if (inner is AggregateException innerAggregate)
                    stack.Push(innerAggregate);
                else
                    yield return inner;
            }
        }
    }

    return tcs.Task;
}
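The framework already ships a recursive flattener, `AggregateException.Flatten()`, which returns a new AggregateException whose inner exceptions contain no AggregateException instances. A shorter variant of the helper could lean on it. In this sketch the extension name `FlattenExceptions` is hypothetical, and the `ExecuteSynchronously`/`TaskScheduler.Default`/`RunContinuationsAsynchronously` options are choices made here, not part of the answer.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class TaskFlattenExtensions
{
    // Hypothetical helper: completes like 'task', but if 'task' faulted, the new
    // task holds only the leaf (non-aggregate) exceptions of the whole tree.
    public static Task FlattenExceptions(this Task task)
    {
        var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

        task.ContinueWith(t =>
        {
            if (t.IsCanceled)
                tcs.SetCanceled();
            else if (t.IsFaulted)
                // Flatten() recursively removes nested AggregateExceptions.
                tcs.SetException(t.Exception.Flatten().InnerExceptions);
            else
                tcs.SetResult(true);
        }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

        return tcs.Task;
    }
}
```

Applied to the question's MyWhenAll, `MyWhenAll(Throw(1), Throw(2)).FlattenExceptions()` should fault with just the two InvalidOperationExceptions.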
Answered by Servy
  • I don't see any indication that he wants to flatten a hierarchy of AggregateExceptions? – canton7 Apr 15 '19 at 15:47
  • @canton7 The whole point of a general solution is to cover the general case, not to write code that will error or behave improperly if there's more than one aggregate exception nested in another aggregate exception. – Servy Apr 15 '19 at 15:50
  • He's asking how to imitate the behaviour of `Task.WhenAll`, and unwrapping multiple levels of `AggregateException` isn't something that `Task.WhenAll` does. – canton7 Apr 15 '19 at 16:00
  • @canton7 That solution is attempting to undo the fact that additional layers of aggregate exceptions were added manually to code to get around the fact that `async` methods don't support providing multiple exceptions. Rather than writing code that only works for a single example, it's coded to work regardless of how many aggregate exceptions are added and regardless of how they're aggregated. `Task.WhenAll` doesn't need to deal with that. If you think it's better to write a solution that simply breaks if there isn't exactly one aggregate exception in one aggregate exception, you can. – Servy Apr 15 '19 at 16:04
  • Thanks, the method isn't that complicated. I'll try your "manual" continuation and see if I can get it to work. – adrianm Apr 15 '19 at 16:14
  • `.ContinueWith(t => Task.WhenAll(t.Result)).Unwrap()` worked exactly as I wanted. Marked as answer. Thanks again. – adrianm Apr 15 '19 at 20:20

Use a TaskCompletionSource.

The outermost exception is created by .Wait() or .Result: they are documented as wrapping the exception stored in the Task inside an AggregateException (to preserve its stack trace; this API was introduced before ExceptionDispatchInfo existed).

However, a Task can actually contain many exceptions. When this is the case, .Wait() and .Result will throw an AggregateException which contains multiple InnerExceptions. You can access this functionality through TaskCompletionSource.SetException(IEnumerable<Exception> exceptions).
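To see the two ways of observing such a multi-exception task side by side, here is a small sketch (`MultiFaultDemo` and its methods are hypothetical names): SetException(IEnumerable<Exception>) stores both exceptions, .Wait() reports them all in one AggregateException, and await rethrows only the first one.

```csharp
using System;
using System.Threading.Tasks;

public static class MultiFaultDemo
{
    // Builds a task that carries two exceptions at once.
    public static Task MakeFaultedTask()
    {
        var tcs = new TaskCompletionSource<object>();
        tcs.SetException(new Exception[]
        {
            new InvalidOperationException("Inner1"),
            new InvalidOperationException("Inner2"),
        });
        return tcs.Task;
    }

    // .Wait() wraps every stored exception in a single AggregateException.
    public static int CountViaWait()
    {
        try { MakeFaultedTask().Wait(); return 0; }
        catch (AggregateException ex) { return ex.InnerExceptions.Count; }
    }

    // await rethrows only the first stored exception.
    public static async Task<string> FirstViaAwait()
    {
        try { await MakeFaultedTask(); return null; }
        catch (InvalidOperationException ex) { return ex.Message; }
    }
}
```

`CountViaWait()` returns 2, while `FirstViaAwait()` completes with "Inner1".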

So you do not want to create your own AggregateException. Set multiple exceptions on the Task, and let .Wait() and .Result create that AggregateException for you.

So:

var tcs = new TaskCompletionSource<object>();
tcs.SetException(new[] { t1.Exception, t2.Exception });
return tcs.Task;

Of course, if you then call await MyWhenAll(..) or MyWhenAll(..).GetAwaiter().GetResult(), then it will only throw the first exception. This matches the behaviour of Task.WhenAll.

This means you need to pass tcs.Task up as your method's return value, which means your method can't be async. You end up doing ugly things like this (adjusting the sample code from your question):

public static Task MyWhenAll(Task t1, Task t2)
{
    var tcs = new TaskCompletionSource<object>();
    _ = Impl(); // fire and forget: Impl completes tcs once t1 and t2 have finished
    return tcs.Task;

    async Task Impl()
    {
        await Task.Delay(10);
        try
        {
            await Task.WhenAll(t1, t2);
            tcs.SetResult(null);
        }
        catch
        {
            tcs.SetException(new[] { t1.Exception, t2.Exception });
        }
    }
}
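As written, the catch block hands t1.Exception and t2.Exception straight to SetException. If only one task faulted, the other's Exception is null, SetException throws ArgumentException inside the discarded Impl task, and tcs.Task never completes. A faulted task also contributes its AggregateException wrapper rather than the original exception. A sketch of a catch block that avoids both problems (`TcsWhenAll` is a hypothetical class name) collects InnerExceptions only from the tasks that actually faulted:

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

public static class TcsWhenAll
{
    public static Task MyWhenAll(Task t1, Task t2)
    {
        var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
        _ = Impl(); // fire and forget: Impl always completes tcs
        return tcs.Task;

        async Task Impl()
        {
            await Task.Delay(10);
            try
            {
                await Task.WhenAll(t1, t2);
                tcs.SetResult(null);
            }
            catch
            {
                // Both tasks are complete here. Keep only real failures and
                // store their original exceptions, not the AggregateException wrappers.
                var errors = new[] { t1, t2 }
                    .Where(t => t.IsFaulted)
                    .SelectMany(t => t.Exception.InnerExceptions)
                    .ToList();

                if (errors.Count > 0)
                    tcs.SetException(errors);
                else
                    tcs.SetCanceled(); // nothing faulted, so at least one task was canceled
            }
        }
    }
}
```

With one faulting task and one successful task, the returned task should fault with exactly the one original exception; with two faulting tasks it carries both.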

At this point, though, I'd start to question why you're trying to do this, and why you can't use the Task returned from Task.WhenAll directly.

Answered by canton7
  • By awaiting the tcs task you defeat the entire purpose of doing this, as it'll only throw the first exception if there are multiple. – Servy Apr 15 '19 at 15:24
  • This doesn't work correctly if only some of the tasks throw. – Servy Apr 15 '19 at 15:40
  • @Servy I know. I copied his code, and adjusted it. I've no idea what he's actually trying to write -- I assume he knows that this bit of his code is incorrect, and was only for illustration. – canton7 Apr 15 '19 at 15:42
  • Thanks, I was thinking about tcs but couldn't really work out how to use it. I'll look closer at your example and see if it goes better this time. – adrianm Apr 15 '19 at 16:29

I deleted my previous answer, because I found a simpler solution. This solution does not involve the pesky ContinueWith method or the TaskCompletionSource type. The idea is to return a nested Task<Task> from a local function, and Unwrap() it from the outer container function. Here is a basic outline of this idea:

public Task<T[]> GetAllAsync<T>()
{
    return LocalAsyncFunction().Unwrap();

    async Task<Task<T[]>> LocalAsyncFunction()
    {
        var tasks = new List<Task<T>>();
        // ...
        await SomethingAsync();
        // ...
        Task<T[]> whenAll = Task.WhenAll(tasks);
        return whenAll;
    }
}

The GetAllAsync method is not async. It delegates all the work to LocalAsyncFunction, which is async, and then calls Unwrap() on the resulting nested task and returns it. The unwrapped task's .Exception.InnerExceptions property contains all the exceptions of the tasks, because it is just a proxy for the inner Task.WhenAll task.
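Applied to the MyWhenAll example from the question, the outline turns into the following sketch. `UnwrapDemo` is a hypothetical class name, and the Task.Delay stands in for whatever the real method awaits before it collects the tasks.

```csharp
using System;
using System.Threading.Tasks;

public static class UnwrapDemo
{
    // Not async: returns the Unwrap() proxy of the inner Task.WhenAll task.
    public static Task MyWhenAll(Task t1, Task t2)
    {
        return Implementation().Unwrap();

        async Task<Task> Implementation()
        {
            await Task.Delay(TimeSpan.FromMilliseconds(100)); // the question's "await Something()"
            // Returned, not awaited, so no exception is thrown (or lost) here.
            return Task.WhenAll(t1, t2);
        }
    }

    // Same shape as Throw(id) in the question.
    public static async Task Throw(int id)
    {
        await Task.Delay(TimeSpan.FromMilliseconds(100));
        throw new InvalidOperationException("Inner" + id);
    }
}
```

`UnwrapDemo.MyWhenAll(Throw(1), Throw(2)).Wait()` should throw a single AggregateException with two inner InvalidOperationExceptions, matching Task.WhenAll.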

Let's demonstrate a more practical realization of this idea. The AsParallelUntilAsync method below lazily enumerates the source sequence and projects each item to a Task<TResult>, until an item satisfies the predicate. It also limits the concurrency of the asynchronous operations. The difficulty is that enumerating the IEnumerable<TSource> could throw an exception too. The correct behavior in this case is to await all the running tasks before propagating the enumeration error, and to return a task whose AggregateException contains both the enumeration error and all the task errors that may have occurred in the meantime. Here is how it can be done:

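// Requires: using System; using System.Collections.Generic; using System.Linq;
// using System.Threading; using System.Threading.Tasks; and C# 8 or later.
public static Task<TResult[]> AsParallelUntilAsync<TSource, TResult>(
    this IEnumerable<TSource> source, Func<TSource, Task<TResult>> action,
    Func<TSource, bool> predicate, int maxConcurrency)
{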
    return Implementation().Unwrap();

    async Task<Task<TResult[]>> Implementation()
    {
        var tasks = new List<Task<TResult>>();

        async Task<TResult> EnumerateAsync()
        {
            var semaphore = new SemaphoreSlim(maxConcurrency, maxConcurrency);
            using var enumerator = source.GetEnumerator();
            while (true)
            {
                await semaphore.WaitAsync();
                if (!enumerator.MoveNext()) break;
                var item = enumerator.Current;
                if (predicate(item)) break;

                async Task<TResult> RunAndRelease(TSource item)
                {
                    try { return await action(item); }
                    finally { semaphore.Release(); }
                }

                tasks.Add(RunAndRelease(item));
            }
            return default; // Dummy value; this result is never observed
        }

        Task<TResult> enumerateTask = EnumerateAsync();

        try
        {
            await enumerateTask; // Make sure that the enumeration succeeded
            Task<TResult[]> whenAll = Task.WhenAll(tasks);
            await whenAll; // Make sure that all the tasks succeeded
            return whenAll;
        }
        catch
        {
            // Return a faulted task that contains ALL the errors!
            return Task.WhenAll(tasks.Prepend(enumerateTask));
        }
    }
}
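The catch block relies on Task.WhenAll collecting the exceptions of every faulted task in the sequence it is given. A stripped-down sketch of just that step, with a pre-faulted task standing in for a failed enumeration (`PrependDemo`, `Combine` and `Fail` are hypothetical names), shows that all three errors end up on one task:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class PrependDemo
{
    static async Task<int> Fail(int id)
    {
        await Task.Delay(10);
        throw new InvalidOperationException("Task" + id);
    }

    public static Task<int[]> Combine()
    {
        // Stand-ins for the tasks started before the enumeration broke.
        var tasks = new List<Task<int>> { Fail(1), Task.FromResult(42), Fail(2) };

        // Stand-in for enumerateTask after the source enumerator threw.
        Task<int> enumerateTask = Task.FromException<int>(
            new InvalidOperationException("Enumeration failed"));

        // Same call as in the answer's catch block: one faulted task with all errors.
        return Task.WhenAll(tasks.Prepend(enumerateTask));
    }
}
```

The successful `Task.FromResult(42)` contributes nothing to the error list; the returned task faults with the enumeration error plus both task errors.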
Answered by Theodor Zoulias