67

I am trying to write a bulk upsert in Python using the SQLAlchemy module (not in SQL!).

I am getting the following error on a SQLAlchemy add:

sqlalchemy.exc.IntegrityError: (IntegrityError) duplicate key value violates unique constraint "posts_pkey"
DETAIL:  Key (id)=(TEST1234) already exists.

I have a table called posts with a primary key on the id column.

In this example, I already have a row in the db with id=TEST1234. When I attempt to db.session.add() a new posts object with the id set to TEST1234, I get the error above. I was under the impression that if the primary key already exists, the record would get updated.

How can I upsert with Flask-SQLAlchemy based on primary key alone? Is there a simple solution?

If there is not, I can always check for and delete any record with a matching id, and then insert the new record, but that seems expensive for my situation, where I do not expect many updates.

mgoldwasser
  • 6
    How is that a duplicate if the original question doesn't mention SQLAlchemy? – techkuz Feb 25 '19 at 15:23
  • Could you consider accepting [exhuma's answer](https://stackoverflow.com/a/63189754/652669)? It leverages PostgreSQL's `INSERT … ON CONFLICT DO UPDATE` feature and works great. – GG. Nov 10 '21 at 23:46

6 Answers

54

There is an upsert-esque operation in SQLAlchemy:

db.session.merge()

After I found this command, I was able to perform upserts, but it is worth mentioning that this operation is slow for a bulk "upsert".
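
For context, here is a minimal sketch of what a single merge() call looks like (the 'title' column is assumed for illustration; only the id column appears in the question):

# merge() first SELECTs the row by primary key, then issues an UPDATE if it
# exists or an INSERT if it does not -- one extra round trip per object.
new_post = posts(id='TEST1234', title='some title')  # 'title' is hypothetical
db.session.merge(new_post)
db.session.commit()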

The alternative is to get a list of the primary keys you would like to upsert, and query the database for any matching ids:

# Imagine that post1, post5, and post1000 are posts objects with ids 1, 5 and 1000 respectively
# The goal is to "upsert" these posts.
# we initialize a dict which maps id to the post object

my_new_posts = {1: post1, 5: post5, 1000: post1000} 

for each in posts.query.filter(posts.id.in_(my_new_posts.keys())).all():
    # Only merge those posts which already exist in the database
    db.session.merge(my_new_posts.pop(each.id))

# Only add those posts which did not exist in the database 
db.session.add_all(my_new_posts.values())

# Now we commit our modifications (merges) and inserts (adds) to the database!
db.session.commit()
mgoldwasser
38

If you are on PostgreSQL, you can leverage the on_conflict_do_update construct. A simple example would be the following:

from sqlalchemy import Column, Integer, Unicode
from sqlalchemy.dialects.postgresql import insert

class Post(Base):
    """
    A simple class for demonstration
    """

    __tablename__ = "post"

    id = Column(Integer, primary_key=True)
    title = Column(Unicode)

# Prepare all the values that should be "upserted" to the DB
values = [
    {"id": 1, "title": "mytitle 1"},
    {"id": 2, "title": "mytitle 2"},
    {"id": 3, "title": "mytitle 3"},
    {"id": 4, "title": "mytitle 4"},
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # The name of the primary-key constraint. The question's error message
    # showed "posts_pkey"; for the demo table above the default name is "post_pkey".
    constraint="post_pkey",

    # The columns that should be updated on conflict
    set_={
        "title": stmt.excluded.title
    }
)
session.execute(stmt)

See the Postgres docs for more details about ON CONFLICT DO UPDATE.

See the SQLAlchemy docs for more details about on_conflict_do_update.
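
If you would rather not hard-code the constraint name, on_conflict_do_update also accepts index_elements, a list of columns making up the conflict target. A minimal sketch, reusing the Post model and values from above:

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # The conflict target given as columns instead of a named constraint
    index_elements=[Post.id],
    set_={"title": stmt.excluded.title},
)
session.execute(stmt)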

Side-Note on duplicated column names

The above code uses the column names as dict keys, both in the values list and in the argument to set_. If a column name is changed in the class definition, it needs to be changed everywhere else as well, or the code will break. This can be avoided by accessing the column definitions, which makes the code a bit uglier but more robust:

coldefs = Post.__table__.c

values = [
    {coldefs.id.name: 1, coldefs.title.name: "mytitle 1"},
    ...
]

stmt = stmt.on_conflict_do_update(
    ...
    set_={
        coldefs.title.name: stmt.excluded.title
        ...
    }
)
GG.
exhuma
  • My `constraint="post_pkey"` code is failing because sqlalchemy can't find the unique constraint which I created in raw sql `CREATE UNIQUE INDEX post_pkey...` and then loaded into sqlalchemy with `metadata.reflect(eng, only="my_table")` after which I received a warning `base.py:3515: SAWarning: Skipped unsupported reflection of expression-based index post_pkey` Any tips for how to fix? – user1071182 Oct 30 '20 at 04:45
  • @user1071182 I think it would be better to post this as a separate question. It would allow you to add more detail. Without seeing the full `CREATE INDEX` statement it is hard to guess what's going wrong here. I can't promise anything though because I have not yet worked with partial indices with SQLAlchemy. But maybe someone else might have a solution. – exhuma Oct 30 '20 at 11:46
  • @exhuma @GG. Thanks for this solution, but I'm facing an issue with this. When I run this, I get an error saying `The 'default' dialect with current database version settings does not support in-place multirow inserts.`. It works fine when I upsert a single value but gives this error on multiple values. Any idea how to fix this? – Nikhil Arora Sep 15 '22 at 16:37
  • @NikhilArora The error hints at an incompatibility between your metadata setup and your database version. Make sure you are using SQLAlchemy at a recent version, and also make sure that you are using PostgreSQL as backend. If both are true, check the PostgreSQL driver you are using. The above was tested on `psycopg2`. – exhuma Sep 22 '22 at 11:28
5

An alternative approach uses the compilation extension (https://docs.sqlalchemy.org/en/13/core/compiler.html):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert)
def compile_upsert(insert_stmt, compiler, **kwargs):
    """
    converts every SQL insert to an upsert  i.e;
    INSERT INTO test (foo, bar) VALUES (1, 'a')
    becomes:
    INSERT INTO test (foo, bar) VALUES (1, 'a') ON CONFLICT(foo) DO UPDATE SET (bar = EXCLUDED.bar)
    (assuming foo is a primary key)
    :param insert_stmt: Original insert statement
    :param compiler: SQL Compiler
    :param kwargs: optional arguments
    :return: upsert statement
    """
    pk = insert_stmt.table.primary_key
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    ondup = f'ON CONFLICT ({",".join(c.name for c in pk)}) DO UPDATE SET'
    updates = ', '.join(f"{c.name}=EXCLUDED.{c.name}" for c in insert_stmt.table.columns)
    upsert = ' '.join((insert, ondup, updates))
    return upsert

This ensures that all INSERT statements behave as upserts. The implementation emits PostgreSQL syntax, but it should be fairly easy to adapt to the MySQL dialect; see the sketch below.
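
For the MySQL case, a hedged sketch of what that adaptation might look like, assuming MySQL's ON DUPLICATE KEY UPDATE syntax and at least one non-key column in the table (registered for the mysql dialect only):

@compiles(Insert, "mysql")
def compile_upsert_mysql(insert_stmt, compiler, **kwargs):
    # MySQL infers the conflict target from the primary/unique keys,
    # so only the non-key columns are listed in the UPDATE part.
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    pk_names = {c.name for c in insert_stmt.table.primary_key}
    updates = ", ".join(
        f"{c.name}=VALUES({c.name})"
        for c in insert_stmt.table.columns
        if c.name not in pk_names
    )
    return f"{insert} ON DUPLICATE KEY UPDATE {updates}"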

danielcahall
  • 1
    Getting this error when using that snippet: `sqlalchemy.exc.ProgrammingError: (psycopg2.errors.SyntaxError) syntax error at or near ")" LINE 1: ...on) VALUES ('US^WYOMING^ALBANY', '') ON CONFLICT () DO UPDAT...` – Mark Coletti Jun 29 '20 at 16:25
  • Ah nice catch! If you don’t have a primary key in your table, this wouldn’t work. Let me add a fix. – danielcahall Jul 02 '20 at 21:36
  • actually, I'm not sure why you would need this if you didn't have a primary key - could you elaborate on the problem? – danielcahall Jul 02 '20 at 22:12
  • 2
    Converting *all* inserts into upserts is risky. Sometimes you *need* to get integrity errors for data consistency and to avoid accidental overwrites. I would only use this solution if you are 120% aware of all the implications this has! – exhuma Jul 31 '20 at 10:07
  • 1
    Note that if you're using Postgres, it's better to use the [built-in ON CONFLICT feature](https://docs.sqlalchemy.org/en/14/orm/persistence_techniques.html#using-postgresql-on-conflict-with-returning-to-return-upserted-orm-objects). – Ramon Dias May 15 '22 at 05:58
2

I started looking at this and I think I've found a pretty efficient way to do upserts in SQLAlchemy with a mix of bulk_insert_mappings and bulk_update_mappings instead of merge:

import time

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from contextlib import contextmanager


engine = None
Session = sessionmaker()
Base = declarative_base()


def create_new_database(db_name="sqlite:///bulk_upsert_sqlalchemy.db"):
    global engine
    engine = create_engine(db_name, echo=False)
    local_session = scoped_session(Session)
    local_session.remove()
    local_session.configure(bind=engine, autoflush=False, expire_on_commit=False)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)


@contextmanager
def db_session():
    local_session = scoped_session(Session)
    session = local_session()

    session.expire_on_commit = False

    try:
        yield session
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()


class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


def bulk_upsert_mappings(customers):

    entries_to_update = []
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that needs to be updated and build mappings
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            customer = customers.pop(each.id)
            entries_to_update.append({"id": customer["id"], "name": customer["name"]})

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.bulk_update_mappings(Customer, entries_to_update)
        sess.commit()

    print(
        "Total time for upsert with MAPPING update "
        + str(len(entries_to_put) + len(entries_to_update))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(len(entries_to_update))
    )


def bulk_upsert_merge(customers):

    entries_to_update = 0
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that needs to be updated and merge
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            values = customers.pop(each.id)
            sess.merge(Customer(id=values["id"], name=values["name"]))
            entries_to_update += 1

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.commit()

    print(
        "Total time for upsert with MERGE update "
        + str(len(entries_to_put) + entries_to_update)
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(entries_to_update)
    )


if __name__ == "__main__":

    batch_size = 10000

    # Only inserts
    customers_insert = {
        i: {"id": i, "name": "customer_" + str(i)} for i in range(batch_size)
    }

    # 50/50 inserts update
    customers_upsert = {
        i: {"id": i, "name": "customer_2_" + str(i)}
        for i in range(int(batch_size / 2), batch_size + int(batch_size / 2))
    }

    create_new_database()
    bulk_upsert_mappings(customers_insert.copy())
    bulk_upsert_mappings(customers_upsert.copy())
    bulk_upsert_mappings(customers_insert.copy())

    create_new_database()
    bulk_upsert_merge(customers_insert.copy())
    bulk_upsert_merge(customers_upsert.copy())
    bulk_upsert_merge(customers_insert.copy())

The results for the benchmark:

Total time for upsert with MAPPING: 0.17138004302978516 sec inserted : 10000 - updated : 0
Total time for upsert with MAPPING: 0.22074174880981445 sec inserted : 5000 - updated : 5000
Total time for upsert with MAPPING: 0.22307634353637695 sec inserted : 0 - updated : 10000
Total time for upsert with MERGE: 0.1724097728729248 sec inserted : 10000 - updated : 0
Total time for upsert with MERGE: 7.852903842926025 sec inserted : 5000 - updated : 5000
Total time for upsert with MERGE: 15.11970829963684 sec inserted : 0 - updated : 10000
  • 1
    Your answer is indeed interesting, but it's good to be aware that there are some drawbacks. As the [documentation](https://docs.sqlalchemy.org/en/14/orm/session_api.html#sqlalchemy.orm.Session.bulk_insert_mappings) says, those bulk methods are _slowly being moved into legacy status_ for performance and safety reasons. Check the **warning section** beforehand. It also says that it's not worth it if you need to bulk upsert on tables with relations. Check the `return_defaults` parameter in the **parameters section** of the same link. – Ramon Dias May 15 '22 at 05:44
1

I know this is kind of late, but I have built on the answer given by @Emil Wåreus and turned it into a function that can be used on any model (table):

def upsert_data(self, entries, model, key):
    entries_to_update = []
    entries_to_insert = []
    
    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)
        
    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()

entries should be a dictionary keyed by the primary key values (as strings, since the function looks existing rows up with str()), and each value should be a mapping of column names to the values for that row.

model is the ORM model that you want to upsert to.

key is the name of the primary key column of the table.
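
A quick usage sketch under those assumptions (upsert_data is written as a method, so it is called here on a hypothetical helper object that owns the session; the Customer model is illustrative):

# Keys are the primary-key values as strings, values map column names to data
entries = {
    "1": {"id": 1, "name": "customer_1"},
    "2": {"id": 2, "name": "customer_2"},
}
db_helper.upsert_data(entries, Customer, "id")  # db_helper is hypothetical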

You can even add a helper function to get the model for the table you want to insert into from its name as a string:

def get_table(self, table_name):
    for c in self.base._decl_class_registry.values():
        if hasattr(c, '__tablename__') and c.__tablename__ == table_name:
            return c

Using this, you can just pass the name of the table as a string to the upsert_data function:

def upsert_data(self, entries, table, key):
    model = get_table(table)
    entries_to_update = []
    entries_to_insert = []
    
    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)
        
    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()
Minura Punchihewa
0

This is not the safest method, but it is very simple and very fast. I was just trying to selectively overwrite a portion of a table. I deleted the rows that I knew would conflict and then appended the new rows from a pandas DataFrame. Your DataFrame column names will need to match your SQL table column names.

eng = create_engine('postgresql://...')
conn = eng.connect()

conn.execute("DELETE FROM my_table WHERE col = %s", val)
df.to_sql('my_table', con=eng, if_exists='append')
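
Note that executing a plain SQL string like this works on SQLAlchemy 1.x; on 2.0 the statement has to be wrapped in text(). A sketch of the same idea in 1.4/2.0 style (my_table, col, val and df are the same placeholders as above):

from sqlalchemy import create_engine, text

eng = create_engine('postgresql://...')
with eng.begin() as conn:  # begin() commits the DELETE on exit
    conn.execute(text("DELETE FROM my_table WHERE col = :val"), {"val": val})
df.to_sql('my_table', con=eng, if_exists='append')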
user1071182