
I have an asynchronous function which connects to a database. Currently my users do:

conn = await connect(uri, other_params)

I want to continue to support this, but want to additionally allow connect() to be used as a context manager:

async with connect(uri, other_params) as conn:
    pass

The difference between these two scenarios is that in the first case connect is awaited, and in the second case it is not.
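A quick way to see why the second form fails with a plain coroutine function: calling it only creates a coroutine object without running its body, and that object has no __aenter__/__aexit__, so async with rejects it before the body of connect ever runs. This is a minimal sketch with a hypothetical stand-in for connect(). On Python 3.11+ the error is a TypeError; older versions raise AttributeError.

```python
import asyncio


async def connect(uri):
    # Hypothetical stand-in for the real connect(); just returns a placeholder.
    return f"connection to {uri}"


async def main():
    coro = connect("db://example")  # calling does NOT run the body yet
    try:
        async with coro as conn:  # a bare coroutine is not an async context manager
            pass
    except (TypeError, AttributeError) as exc:
        result = type(exc).__name__
    finally:
        coro.close()  # avoid the "coroutine was never awaited" warning
    return result


print(asyncio.run(main()))
```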

Is it possible to tell, within the body of connect, if the coroutine was awaited or not?

My current effort at this is on repl.it.

LondonRob

2 Answers


Here's code that passes the tests you provided:

import asyncio
import pytest
from functools import wraps


def connection_context_manager(func):
    @wraps(func)
    def wrapper(*args, **kwargs):

        class Wrapper:
            def __init__(self):
                self._conn = None

            async def __aenter__(self):
                # `async with connect(...) as conn`: open the connection on entry
                self._conn = await func(*args, **kwargs)
                return self._conn

            async def __aexit__(self, *_):
                # Close the connection when the `async with` block exits
                await self._conn.close()

            def __await__(self):
                # `conn = await connect(...)`: delegate to the wrapped coroutine
                return func(*args, **kwargs).__await__()  # https://stackoverflow.com/a/33420721/1113207

        return Wrapper()

    return wrapper

Note how these three magic methods (`__aenter__`, `__aexit__` and `__await__`) allow us to make the object both awaitable and an async context manager at the same time.
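To show both calling styles working against the same decorated function, here is a self-contained sketch. FakeConnection is a hypothetical stand-in for a real database connection (it only needs an async close() method), and the decorator is repeated so the example runs on its own.

```python
import asyncio
from functools import wraps


class FakeConnection:
    """Hypothetical stand-in for a real database connection."""

    def __init__(self, uri):
        self.uri = uri
        self.open = True

    async def close(self):
        self.open = False


def connection_context_manager(func):
    # Same decorator as above, reproduced so this example is self-contained.
    @wraps(func)
    def wrapper(*args, **kwargs):
        class Wrapper:
            def __init__(self):
                self._conn = None

            async def __aenter__(self):
                self._conn = await func(*args, **kwargs)
                return self._conn

            async def __aexit__(self, *_):
                await self._conn.close()

            def __await__(self):
                return func(*args, **kwargs).__await__()

        return Wrapper()

    return wrapper


@connection_context_manager
async def connect(uri):
    return FakeConnection(uri)


async def main():
    # Style 1: plain await; the caller is responsible for closing.
    conn = await connect("db://one")
    awaited_open = conn.open
    await conn.close()

    # Style 2: async context manager; closed automatically on exit.
    async with connect("db://two") as ctx_conn:
        inside_open = ctx_conn.open
    return awaited_open, inside_open, ctx_conn.open


print(asyncio.run(main()))  # (True, True, False)
```

Because connect is now a plain (non-async) function returning a Wrapper, nothing runs until the caller either awaits it or enters it with async with, so connect never needs to know which style was used.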

Feel free to ask questions if you have any.

LondonRob
Mikhail Gerasimov
  • Great answer! I'd just recommend avoiding creating a new class every time the decorator is invoked because classes are comparatively heavyweight and slow to construct. And it's quite easy to avoid it - just define `Wrapper` at top-level, pass `func` to its constructor, and use `self._func` where needed. – user4815162342 Oct 26 '19 at 06:37
  • You and @user4815162342 are my heroes. This is good stuff. I actually didn't know there was an `__await__` magic. Thanks both. – LondonRob Oct 26 '19 at 16:28
  • @user4815162342's answer is shown in detail [on my Github](https://github.com/roblevy/async-and-sync) – LondonRob Oct 26 '19 at 17:16
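A sketch of the refactor user4815162342 suggests: the class is defined once at module level and receives func (plus the call arguments) through its constructor, so each call to connect only creates an instance rather than a new class. ConnectionContextManager and FakeConnection are hypothetical names chosen for this example; the version on LondonRob's GitHub may differ in detail.

```python
import asyncio
from functools import wraps


class ConnectionContextManager:
    """Awaitable and async context manager; the class is built only once."""

    def __init__(self, func, args, kwargs):
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._conn = None

    def __await__(self):
        # `await connect(...)`: run the wrapped coroutine and return its result.
        return self._func(*self._args, **self._kwargs).__await__()

    async def __aenter__(self):
        # `async with connect(...) as conn`: open the connection on entry.
        self._conn = await self._func(*self._args, **self._kwargs)
        return self._conn

    async def __aexit__(self, *_):
        await self._conn.close()


def connection_context_manager(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Only an instance is created per call.
        return ConnectionContextManager(func, args, kwargs)

    return wrapper


class FakeConnection:
    """Hypothetical connection object with an async close()."""

    def __init__(self, uri):
        self.uri = uri
        self.open = True

    async def close(self):
        self.open = False


@connection_context_manager
async def connect(uri):
    return FakeConnection(uri)


async def main():
    conn = await connect("db://one")
    awaited_open = conn.open
    await conn.close()
    async with connect("db://two") as ctx_conn:
        inside_open = ctx_conn.open
    return awaited_open, inside_open, ctx_conn.open


print(asyncio.run(main()))  # (True, True, False)
```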

I think there is a structural issue in your example.

First, with your current structure the test has to await the __call__ before entering the context:

@pytest.mark.asyncio
async def test_connect_with_context_manager():
    async with await connect("with context uri") as no_context:
        # Now it's open
        assert no_context.open

    # Now it's closed
    assert not no_context.open

But the problem here is that the result of await connect("with context uri") is a Connection, which doesn't have an __aexit__ method at all.

So I believe you should change the structure completely: add a connect method to Connection that actually opens the connection, and have MagicConnection delegate every method to the Connection.
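A minimal sketch of that restructuring, assuming hypothetical versions of Connection and MagicConnection (the originals on repl.it are not shown here): Connection gains an async connect() method, and MagicConnection forwards ordinary attribute access to it via __getattr__. Special methods such as __await__, __aenter__ and __aexit__ are looked up on the type and bypass __getattr__, so they have to be defined explicitly on the wrapper.

```python
import asyncio


class Connection:
    """Hypothetical connection; connect() is what actually opens it."""

    def __init__(self, uri):
        self.uri = uri
        self.open = False

    async def connect(self):
        self.open = True  # real code would do the network I/O here
        return self

    async def close(self):
        self.open = False

    def describe(self):
        return f"{self.uri} open={self.open}"


class MagicConnection:
    """Wraps a Connection and delegates every other attribute to it."""

    def __init__(self, uri):
        self._conn = Connection(uri)

    def __getattr__(self, name):
        # Only called for attributes MagicConnection itself does not define.
        return getattr(self._conn, name)

    async def _open(self):
        await self._conn.connect()
        return self

    def __await__(self):
        # `await connect(uri)` yields the opened MagicConnection.
        return self._open().__await__()

    async def __aenter__(self):
        return await self

    async def __aexit__(self, *_):
        await self._conn.close()


def connect(uri):
    # A plain function: no coroutine is created, so both usages work.
    return MagicConnection(uri)


async def main():
    conn = await connect("db://one")
    first = conn.describe()  # delegated to Connection.describe
    await conn.close()
    async with connect("db://two") as ctx:
        inside = ctx.open
    return first, inside, ctx.open


print(asyncio.run(main()))  # ('db://one open=True', True, False)
```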

Sraw