I'm trying to use a fixed-size queue to process a variable (and larger) number of tasks with asyncio. Everything works fine when no task raises, but when a task does raise an exception I'd like to fail early, on the first exception.
Right now, all exceptions are being skipped silently with the code below. I know I can capture exceptions in the same way I'm capturing valid task results and then raise them later, but I want to raise on the first exception, not at the end.
What am I missing?
import asyncio
import threading
from typing import Awaitable, Callable, List
import aiohttp
import aiostream
def async_wrap_iter(it):
    """Wrap blocking iterator into an asynchronous one.

    Source: https://stackoverflow.com/a/62297994/7619676
    """
    loop = asyncio.get_event_loop()
    q = asyncio.Queue(1)
    exception = None
    _END = object()

    async def yield_queue_items():
        while True:
            next_item = await q.get()
            if next_item is _END:
                break
            yield next_item
        if exception is not None:
            # the iterator has raised, propagate the exception
            raise exception

    def iter_to_queue():
        nonlocal exception
        try:
            for item in it:
                # This runs outside the event loop thread, so we
                # must use thread-safe API to talk to the queue.
                asyncio.run_coroutine_threadsafe(q.put(item), loop).result()
        except Exception as e:
            exception = e
        finally:
            asyncio.run_coroutine_threadsafe(q.put(_END), loop).result()

    threading.Thread(target=iter_to_queue).start()
    return yield_queue_items()

async def main(
    rows,
    func: Callable[[List], Awaitable[None]],
    batch_size: int = 20,
    max_workers: int = 50,
) -> List:
    """Adapted from https://stackoverflow.com/a/62404509/7619676"""
    queue = asyncio.Queue(max_workers)
    results = []

    async def worker(func, queue, results):
        while True:
            batch = await queue.get()
            try:
                results.append(await func(batch))
            except Exception:
                raise
            finally:
                queue.task_done()

    # Create `max_workers` workers and feed them tasks.
    workers = [
        asyncio.create_task(worker(func, queue, results))
        for _ in range(max_workers)
    ]

    # Feed the database rows to the workers. The fixed capacity of the
    # queue ensures that we never hold all rows in memory at the same time.
    # When the queue reaches full capacity, this will block until a worker
    # dequeues an item.
    rows = async_wrap_iter(rows)
    async with aiostream.stream.chunks(rows, batch_size).stream() as chunks:
        async for batch in chunks:
            await queue.put(batch)  # enqueue a batch of `batch_size` rows

    await queue.join()
    for w in workers:
        w.cancel()
    return results

async def func_that_errors_on_evens(batch):
    i = batch[0]
    print(i)
    if i % 2 == 0:
        raise Exception("fake")
    return i

rows = [1, 2, 3, 4]
asyncio.run(main(rows=rows, func=func_that_errors_on_evens, batch_size=1, max_workers=2))
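As far as I can tell the exceptions disappear because a Task's exception only surfaces when you await the task (or call its result()), and my worker tasks are never awaited, only cancelled. A toy example of what I mean (the boom/demo names are just for illustration, not from the code above):

import asyncio

async def boom():
    raise Exception("fake")

async def demo():
    t = asyncio.create_task(boom())
    await asyncio.sleep(0)  # yield control so the task runs and raises internally
    print("demo still running: the exception is stored on the task")
    await t  # only now does the exception actually propagate

asyncio.run(demo())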
Based on https://stackoverflow.com/a/59629996/7619676, instead of await queue.join()
I tried the following:
join_task = asyncio.create_task(queue.join())
done, _ = await asyncio.wait(
    [join_task, *workers], return_when=asyncio.FIRST_EXCEPTION
)  # alternatively, use asyncio.ALL_COMPLETED to raise "late"
consumers_raised = set(done) & set(workers)
if consumers_raised:
    await consumers_raised.pop()  # propagate the exception
While that solution works if there IS an exception, it seems to hang forever if there's NOT an exception.
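For the record, I suspect the hang happens because the workers never finish on their own: with return_when=asyncio.FIRST_EXCEPTION and no exception, asyncio.wait only returns once every awaitable has completed, and the workers loop forever. Waiting with FIRST_COMPLETED on the join task instead seems like it should return in both cases. A sketch of what I mean (again replacing await queue.join(); join_task is my own name):

join_task = asyncio.create_task(queue.join())
done, _ = await asyncio.wait(
    [join_task, *workers], return_when=asyncio.FIRST_COMPLETED
)
# The workers loop forever, so a worker can only be "done" by raising.
consumers_raised = set(done) & set(workers)
if consumers_raised:
    await consumers_raised.pop()  # propagate the first exception

Even then, this would only kick in after all rows have been enqueued, since the producer loop runs before the wait, so I'm not sure it fails as early as I want.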