
I'm trying to use a fixed-size queue to process a variable (and larger) number of tasks with asyncio. Everything works fine when there are no task exceptions, but when a task raises I'd like to fail early, on the first exception.

Right now, all exceptions are silently skipped with the code below. I know I can capture exceptions the same way I'm capturing valid task results and raise them later, but I want to raise on the first exception, not at the end.

What am I missing?

import asyncio
import threading
from typing import Awaitable, Callable, List

import aiohttp
import aiostream


def async_wrap_iter(it):
    """Wrap blocking iterator into an asynchronous one.

    Source: https://stackoverflow.com/a/62297994/7619676
    """
    loop = asyncio.get_event_loop()
    q = asyncio.Queue(1)
    exception = None
    _END = object()

    async def yield_queue_items():
        while True:
            next_item = await q.get()
            if next_item is _END:
                break
            yield next_item
        if exception is not None:
            # the iterator has raised, propagate the exception
            raise exception

    def iter_to_queue():
        nonlocal exception
        try:
            for item in it:
                # This runs outside the event loop thread, so we
                # must use thread-safe API to talk to the queue.
                asyncio.run_coroutine_threadsafe(q.put(item), loop).result()
        except Exception as e:
            exception = e
        finally:
            asyncio.run_coroutine_threadsafe(q.put(_END), loop).result()

    threading.Thread(target=iter_to_queue).start()
    return yield_queue_items()


async def main(
    rows,
    func: Callable[[List], Awaitable],
    batch_size: int = 20,
    max_workers: int = 50,
) -> List:
    """Adapted from https://stackoverflow.com/a/62404509/7619676"""
    queue = asyncio.Queue(max_workers)
    results = []

    async def worker(func, queue, results):
        while True:
            batch = await queue.get()
            try:
                results.append(await func(batch))
            except Exception:
                raise
            finally:
                queue.task_done()

    # create `max_workers` workers and feed them tasks.
    workers = [
        asyncio.create_task(worker(func, queue, results))
        for _ in range(max_workers)
    ]

    # Feed the database rows to the workers.
    # The fixed-capacity of the queue ensures that we never hold all rows in memory at the same time.
    # When the queue reaches full capacity, this will block until a worker dequeues an item.
    rows = async_wrap_iter(rows)
    async with aiostream.stream.chunks(rows, batch_size).stream() as chunks:
        async for batch in chunks:
            await queue.put(batch)  # enqueue a batch of `batch_size` rows

    await queue.join()

    for w in workers:
        w.cancel()

    return results


async def func_that_errors_on_evens(batch):
    i = batch[0]
    print(i)
    if i % 2 == 0:
        raise Exception("fake")
    return i

rows = [1, 2, 3, 4]
asyncio.run(main(rows=rows, func=func_that_errors_on_evens, batch_size=1, max_workers=2))

Based on https://stackoverflow.com/a/59629996/7619676, instead of await queue.join() I tried the following:

done, _ = await asyncio.wait(
    [queue.join(), *workers], return_when=asyncio.FIRST_EXCEPTION
)  # alternatively, use asyncio.ALL_COMPLETED to raise "late"
consumers_raised = set(done) & set(workers)
if consumers_raised:
    await consumers_raised.pop()  # propagate the exception

While that solution works if there IS an exception, it seems to hang forever if there's NOT one. That makes sense in hindsight: with return_when=asyncio.FIRST_EXCEPTION, asyncio.wait behaves like ALL_COMPLETED when nothing raises, and since the workers loop forever they never complete, so the wait never returns even after queue.join() finishes.
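
For reference, here is a sketch of a variant that avoids the hang (assuming the same queue and workers as above; a sketch, not a tested drop-in fix): wrap queue.join() in a task so asyncio.wait can return as soon as either the queue drains or a worker dies:

join_task = asyncio.create_task(queue.join())
done, _ = await asyncio.wait(
    [join_task, *workers], return_when=asyncio.FIRST_COMPLETED
)
for w in workers:
    w.cancel()
join_task.cancel()
for task in done:
    if task is not join_task:
        # A worker can only finish by raising, so awaiting it propagates the exception.
        await task

Wrapping the bare coroutine in a task is also required on newer Python versions, where asyncio.wait no longer accepts coroutines.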

  • As far as I can see, the only function of the Queue is to block when 50 database items are in memory at the same time. Have you considered using an asyncio.Semaphore(50) to accomplish the same thing? Then you could get rid of the Queue and perhaps the second thread, and capture each return value as it becomes available instead of collecting them all in a List. In that case I think you can catch an Exception fairly easily. – Paul Cornelius Jul 17 '22 at 02:00
  • The real use case needs to limit memory usage, which I don’t think I can accomplish without the queue, right? – ZaxR Jul 17 '22 at 02:29
  • Why do you need a queue? A queue has nothing directly to do with memory usage. You are using the queue to prevent the creation of new tasks when you have 50 of them already running, correct? All I'm saying is that there are other, much simpler ways to do that. And the use of the queue is what makes it difficult for you to access the results (and therefore the exceptions) one at a time, since there is no obvious way to return the results to your main task until you have collected them all. – Paul Cornelius Jul 17 '22 at 03:00
  • Correct: the queue limits memory by preventing more data from being loaded and more tasks from being created once 50 chunks are loaded/running. What would you suggest as an alternative mechanism? The trouble I hit when trying to think through a semaphore solution is how to avoid creating all the tasks up front – ZaxR Jul 17 '22 at 05:18

1 Answer


Based on @PaulCornelius's comments, I refactored to use a Semaphore instead of a Queue. (For the record, the original code skipped exceptions silently because a worker's exception is stored on its task, and the workers were cancelled without ever being awaited, so the exception was never retrieved.) The trickiest part, propagating exceptions early, still remained: I didn't want to create all the tasks upfront (for memory efficiency), which seemingly ruled out asyncio.gather. I worked around this by writing a custom, non-waiting function to check completed tasks for errors:

def propagate_exceptions(tasks: List[asyncio.Task]) -> None:
    """Raise the first exception found in a list of tasks, in task-creation order.

    Note that if multiple tasks have errored, the exception raised is that of
    the earliest-created task, not necessarily the first task to have errored.
    """
    for task in tasks:
        if task.done() and task.exception() is not None:
            raise task.exception()
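
To illustrate the non-waiting behavior (a hypothetical demo, not part of the pipeline): the helper raises as soon as it sees an already-finished task holding an exception, and ignores tasks that are still pending:

async def demo():
    async def boom():
        raise ValueError("fake")

    failed = asyncio.create_task(boom())
    pending = asyncio.create_task(asyncio.sleep(10))
    await asyncio.sleep(0)  # yield once so `failed` can run and error
    try:
        propagate_exceptions([failed, pending])  # raises ValueError without waiting
    finally:
        pending.cancel()

asyncio.run(demo())  # ValueError: fake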

And then adding a final await asyncio.gather at the end:

async def main(
    rows,
    func: Callable[[List], Awaitable],
    batch_size: int = 20,
    max_workers: int = 50,
) -> List:
    semaphore = asyncio.Semaphore(max_workers)
    tasks = []  # a list (not a set), so propagate_exceptions sees tasks in creation order

    rows = async_wrap_iter(rows)

    # The fixed capacity of the semaphore ensures that we never hold all rows in memory at the same time.
    # When all of its slots are taken, acquire() blocks until a running task completes and releases one.
    async with aiostream.stream.chunks(rows, batch_size).stream() as chunks:
        async for batch in chunks:
            await semaphore.acquire()
            task = asyncio.create_task(func(batch))
            task.add_done_callback(lambda task: semaphore.release())
            tasks.append(task)
            # Propagate and raise exceptions early.
            # Note that this doesn't await any task completion,
            # so an exception may surface in a later loop iteration or in the final asyncio.gather
            propagate_exceptions(tasks)

    # This will wait for all tasks to complete, propagating any final exceptions.
    # Needed to catch any final errors from tasks that hadn't finished when the generator loop ended.
    # It is also responsible for returning task results if no exceptions have occurred.
    return await asyncio.gather(*tasks, return_exceptions=False)
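
Driving it the same way as the example in the question now surfaces the first exception instead of silently dropping it (it may be raised by propagate_exceptions in a later loop iteration, or by the final gather):

rows = [1, 2, 3, 4]
asyncio.run(main(rows=rows, func=func_that_errors_on_evens, batch_size=1, max_workers=2))

One caveat worth noting: if propagate_exceptions raises mid-loop, any still-pending tasks are left running; cancelling them before re-raising may be desirable.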