
I've been using a synchronous SQLAlchemy Session in my FastAPI app. Now we're moving to asynchronous DB calls, and the piece of code that helped us roll back the DB interactions after each unit test no longer works.

My code snippet is based on https://stackoverflow.com/a/67348153/6495199, and I updated it following this suggestion: https://github.com/sqlalchemy/sqlalchemy/issues/5811#issuecomment-756269881

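import pytest
import sqlalchemy
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

# app, get_db and get_db_url come from the application code
test_engine = create_async_engine(get_db_url(), poolclass=StaticPool)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine, class_=AsyncSession)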

@pytest.fixture()
async def db_session():
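    connection = await test_engine.connect()

    # Outer transaction: rolled back at teardown to undo everything the test did
    transaction = await connection.begin()
    # SAVEPOINT, so that commits made by the code under test stay inside the outer transaction
    await connection.begin_nested()
    session = TestingSessionLocal(bind=connection)

    # When the SAVEPOINT ends (the app commits or rolls back), open a new one
    @sqlalchemy.event.listens_for(session.sync_session, "after_transaction_end")
    def end_savepoint(session, transaction):
        if connection.closed:
            return
        if not connection.in_nested_transaction():
            connection.sync_connection.begin_nested()
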
    yield session
    # Roll back the outer transaction, restoring the state from before the test ran.
    await session.close()
    await transaction.rollback()
    await connection.close()

@pytest.fixture()
def client(db_session):
    # async def override_get_db():
    def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    yield TestClient(app)
    del app.dependency_overrides[get_db]

def test_get_by_id(client):
    response = client.get("/users/1")
    assert response.status_code == 404

At some point my code runs this function, but I get an error:

async def find_by_id(cls, session: AsyncSession, pk: int):
    return await session.get(cls, pk)

# AttributeError: 'async_generator' object has no attribute 'get'
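
The error says the object passed as session is an async_generator. For context, that is what calling an `async def` function containing `yield` returns: the body has not run yet, and the yielded value only appears once the generator is advanced with `await`. A stdlib-only sketch, where FakeSession is a hypothetical stand-in for AsyncSession:

```python
import asyncio
import inspect


class FakeSession:
    """Hypothetical stand-in for AsyncSession; only provides get()."""

    async def get(self, cls, pk):
        return None  # behaves like "row not found"


async def db_session():
    # Same shape as the async fixture: an async generator yielding a session.
    yield FakeSession()


# Calling it does not run the body; it only creates an async generator object.
gen = db_session()
print(inspect.isasyncgen(gen))  # True
print(hasattr(gen, "get"))      # False, hence the AttributeError in the question


async def main():
    # The session only appears once the generator is advanced with await.
    session = await gen.__anext__()
    result = await session.get(object, 1)
    await gen.aclose()  # run the generator's teardown (nothing here)
    return session, result


session, result = asyncio.run(main())
print(type(session).__name__, result)  # FakeSession None
```

A test can receive this same kind of unadvanced object when an async generator fixture is handed over without being driven by an async-aware plugin such as pytest-asyncio.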
Carlos Rojas

1 Answer


Related to this answer to a similar problem: your fixture is called db_session, but you are not passing it as an argument to the test function find_by_id. You are passing session instead, which results in an async generator being yielded instead of your db_session.

When pytest runs a test, it looks for fixtures that have the same names as the parameters of the function.

The fix is simple: change your test function's signature to:

async def find_by_id(cls, db_session: AsyncSession, pk: int):
    ...
amoralesc