
I have a program (an ASGI server) that is structured roughly like this:

import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifespan():
    ctxvar.set("spam")


async def endpoint():
    assert ctxvar.get() == "spam"


async def main():
    ctx = contextvars.copy_context()
    task = asyncio.create_task(lifespan())
    await task
    task = asyncio.create_task(endpoint())
    await task

asyncio.run(main())

Because the lifespan event and the endpoints run in separate tasks, they can't share context variables. This is by design: each task runs in a copy of the context that was current when the task was created, so the value that lifespan sets is never visible to endpoint. That isolation is the desired behavior between endpoints, but I would like execution to appear like this (from a user's perspective):

async def lifespan():
    ctxvar.set("spam")
    await endpoint()

In other words, each endpoint should run in its own independent context, but one that inherits whatever the lifespan set.
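
The intended semantics can be pinned down with synchronous stand-ins for lifespan and endpoint, because Context.run does apply to ordinary function calls. In this sketch (lifespan_ctx, results and the name parameter are illustrative, and ctxvar is given a default so it can be read while unset), lifespan writes into one Context and each endpoint runs in a copy of it: every endpoint sees "spam", but its own writes do not leak back.

```python
import contextvars

# The default is an assumption added for the demo, not part of the question.
ctxvar = contextvars.ContextVar("ctx", default="unset")


def lifespan() -> None:
    # Synchronous stand-in for the async lifespan.
    ctxvar.set("spam")


def endpoint(name: str) -> str:
    seen = ctxvar.get()  # inherited from the lifespan's context
    ctxvar.set(name)     # changes only this endpoint's private copy
    return seen


lifespan_ctx = contextvars.Context()
lifespan_ctx.run(lifespan)

# Each endpoint gets its own copy of the lifespan's context.
results = [lifespan_ctx.copy().run(endpoint, name) for name in ("a", "b")]
print(results)               # ['spam', 'spam']
print(lifespan_ctx[ctxvar])  # spam
```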

I tried to get this to work by using contextvars.copy_context():

import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifespan():
    ctxvar.set("spam")
    print("set")


async def endpoint():
    print("get")
    assert ctxvar.get() == "spam"


async def main():
    ctx = contextvars.copy_context()
    task = ctx.run(asyncio.create_task, lifespan())
    await task
    endpoint_ctx = ctx.copy()
    task = endpoint_ctx.run(asyncio.create_task, endpoint())
    await task

asyncio.run(main())

As well as:

async def main():
    ctx = contextvars.copy_context()
    task = asyncio.create_task(ctx.run(lifespan))
    await task
    endpoint_ctx = ctx.copy()
    task = asyncio.create_task(endpoint_ctx.run(endpoint))
    await task

However, it seems that contextvars.Context.run does not work this way (I guess the context is bound when the coroutine is created, not when it is executed).
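
Both attempts can be checked directly. The relevant fact is that asyncio.create_task() snapshots the current context with copy_context() when the task is created, and the task then runs in that snapshot rather than in ctx itself. In the sketch below, attempt_one and attempt_two are illustrative names, and ctxvar gets a default (an assumption) so it can be read while unset; neither attempt leaves "spam" in ctx.

```python
import asyncio
import contextvars

# A default (an assumption, not in the question) lets us read the var while unset.
ctxvar = contextvars.ContextVar("ctx", default="unset")


async def lifespan() -> None:
    ctxvar.set("spam")


async def attempt_one() -> str:
    ctx = contextvars.copy_context()
    # create_task() runs inside ctx, so the new task snapshots ctx...
    task = ctx.run(asyncio.create_task, lifespan())
    await task
    # ...but the task executed in its own copy, so ctx itself is unchanged.
    return ctx.run(ctxvar.get)


async def attempt_two() -> str:
    ctx = contextvars.copy_context()
    # Calling an async function only creates a coroutine object; its body has
    # not started, so ctx.run() returns before ctxvar.set() is ever reached.
    coro = ctx.run(lifespan)
    # The task then runs in a copy of main's current context, not ctx.
    await asyncio.create_task(coro)
    return ctx.run(ctxvar.get)


print(asyncio.run(attempt_one()))  # unset
print(asyncio.run(attempt_two()))  # unset
```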

Is there a simple way to achieve the desired behavior without restructuring how the tasks are created?

  • This is how context variables are supposed to work. If you want two coroutines to share the same value, you can use some key-value storage, e.g. an ordinary dictionary. – Artiom Kozyrev Aug 04 '21 at 15:57
  • With full control over everything, that would make sense. The larger context is a library, which we'll call *B*. *B* is called by another library, *A*, which does the task scheduling (hence why I said "without restructuring how the tasks are being created"). In turn, *B* calls user code. Users may want to use context variables, or use a library that itself uses context variables. The goal is to give users reasonable behavior without requiring the users of *A* to modify their code. – LoveToCode Aug 04 '21 at 17:44

2 Answers


Here's what I came up with, inspired by PEP 555 and asgiref:

from contextvars import Context, ContextVar, copy_context
from typing import Any


def _set_cvar(cvar: ContextVar, val: Any):
    cvar.set(val)


class CaptureContext:

    def __init__(self) -> None:
        self.context = Context()

    def __enter__(self) -> "CaptureContext":
        self._outer = copy_context()
        return self

    def sync(self):
        final = copy_context()
        for cvar in final:
            if cvar not in self._outer:
                # new contextvar set
                self.context.run(_set_cvar, cvar, final.get(cvar))
            else:
                final_val = final.get(cvar)
                if self._outer.get(cvar) != final_val:
                    # value changed
                    self.context.run(_set_cvar, cvar, final_val)

    def __exit__(self, *args: Any):
        self.sync()


def restore_context(context: Context) -> None:
    """Copy every variable set in `context` into the current Context."""
    for cvar, value in context.items():
        cvar.set(value)

Usage:

import asyncio
import contextvars

# CaptureContext and restore_context are the helpers defined above.

ctxvar = contextvars.ContextVar("ctx")


async def lifespan(cap: CaptureContext):
    with cap:
        ctxvar.set("spam")


async def endpoint():
    assert ctxvar.get() == "spam"


async def main():
    cap = CaptureContext()
    await asyncio.create_task(lifespan(cap))
    # Copy what the lifespan set into main's context; later tasks inherit it.
    restore_context(cap.context)
    task = asyncio.create_task(endpoint())
    await task

asyncio.run(main())

The sync() method is provided in case the task is long-running and you need to capture the context before it finishes. A somewhat contrived example:

import asyncio
import contextvars

# CaptureContext and restore_context are the helpers defined above.

ctxvar = contextvars.ContextVar("ctx")


async def lifespan(cap: CaptureContext, event: asyncio.Event):
    with cap:
        ctxvar.set("spam")
        cap.sync()  # publish the changes made so far without leaving the block
        event.set()
        await asyncio.sleep(float("inf"))


async def endpoint():
    assert ctxvar.get() == "spam"


async def main():
    cap = CaptureContext()
    event = asyncio.Event()
    # Keep a reference to the long-running task, as the asyncio docs recommend.
    lifespan_task = asyncio.create_task(lifespan(cap, event))
    await event.wait()
    restore_context(cap.context)
    task = asyncio.create_task(endpoint())
    await task
    lifespan_task.cancel()

asyncio.run(main())

I think it would still be much nicer if contextvars.Context.run worked with coroutines.


Python 3.11 added support for this: asyncio.create_task() accepts a context argument (https://github.com/python/cpython/issues/91150).

With it, you can write:

async def main():
    ctx = contextvars.copy_context()
    task = asyncio.create_task(lifespan(), context=ctx)
    await task
    endpoint_ctx = ctx.copy()
    task = asyncio.create_task(endpoint(), context=endpoint_ctx)
    await task

On Python versions before 3.11, you will need a backport of this feature. I can't think of a good one, but a bad one is here.
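
For versions before 3.11, one possible workaround is to drive the coroutine step by step and wrap every send()/throw() in Context.run, so the coroutine body always executes inside the chosen Context while the surrounding Task keeps its own copy. This is a minimal sketch, not the linked backport; run_in_context and _drive are hypothetical helper names. Two tasks driven in the same Context would see each other's writes, which is why the endpoint below gets ctx.copy(). On 3.11+, prefer create_task(..., context=...).

```python
import asyncio
import contextvars
import types
from typing import Any, Coroutine, Optional


@types.coroutine
def _drive(coro: Coroutine[Any, Any, Any], context: contextvars.Context):
    """Hypothetical helper: step `coro` manually, running every step inside `context`."""
    send_value: Any = None
    to_throw: Optional[BaseException] = None
    while True:
        try:
            if to_throw is None:
                yielded = context.run(coro.send, send_value)
            else:
                yielded = context.run(coro.throw, to_throw)
        except StopIteration as stop:
            return stop.value  # the coroutine finished; hand back its result
        send_value, to_throw = None, None
        try:
            # Pass whatever the coroutine yielded (e.g. a Future) up to the Task.
            send_value = yield yielded
        except BaseException as exc:  # e.g. CancelledError from task.cancel()
            to_throw = exc


async def run_in_context(context: contextvars.Context, coro: Coroutine[Any, Any, Any]) -> Any:
    """Hypothetical helper: await `coro` with its body running inside `context`."""
    return await _drive(coro, context)


ctxvar = contextvars.ContextVar("ctx")


async def lifespan() -> None:
    ctxvar.set("spam")


async def endpoint() -> str:
    await asyncio.sleep(0)  # suspend once to exercise the send/yield loop
    return ctxvar.get()


async def main() -> str:
    ctx = contextvars.copy_context()
    await asyncio.create_task(run_in_context(ctx, lifespan()))
    endpoint_ctx = ctx.copy()
    return await asyncio.create_task(run_in_context(endpoint_ctx, endpoint()))


print(asyncio.run(main()))  # spam
```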
