0

Suppose we have these models:


class Country(Base):
    """A country row, identified by a unique short ``code`` (e.g. ``"en"``)."""

    __tablename__ = "countries"

    # No ``index=True`` on the primary key: the PK constraint is already backed
    # by an index, so an extra ``ix_countries_id`` index would be redundant.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # ``index=True`` combined with ``unique=True`` yields a single UNIQUE index.
    code = Column(String, index=True, nullable=False, unique=True)
    name = Column(String, nullable=False)


class EventSource(Base):
    """A source of events (e.g. ``"TV"``), identified by a unique ``name``."""

    __tablename__ = "eventsources"

    # No ``index=True`` on the primary key: the PK constraint is already backed
    # by an index, so an extra ``ix_eventsources_id`` index would be redundant.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String, nullable=False, unique=True)


class Event(Base):
    """An event record linked to exactly one ``Country`` and one ``EventSource``."""

    __tablename__ = "events"

    # No ``index=True`` on the primary key: the PK constraint is already backed
    # by an index, so an extra ``ix_events_id`` index would be redundant.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # ON DELETE CASCADE: deleting a country removes its events in the database.
    country_id = Column(
        UUID(as_uuid=True),
        ForeignKey("countries.id", ondelete="CASCADE"),
        nullable=False,
    )
    # ON DELETE CASCADE: deleting an event source removes its events as well.
    eventsource_id = Column(
        UUID(as_uuid=True),
        ForeignKey("eventsources.id", ondelete="CASCADE"),
        nullable=False,
    )
    # ``func.now()`` is a SQL expression, rendered into the INSERT statement,
    # so the timestamp comes from the database clock.
    created_at = Column(DateTime(timezone=True), default=func.now())

The task is to count the events for each country, per event source. In raw SQL this is straightforward:

SELECT eventsources.name, countries.code, COUNT(events.id) AS events_count
FROM events 
  JOIN countries ON events.country_id = countries.id 
  JOIN eventsources ON events.eventsource_id = eventsources.id 
GROUP BY eventsources.name, countries.code;

The result gives, for each event source, the number of events grouped by country. The question is how to set up the models and write the query in SQLAlchemy (preferably in 2.0-style syntax) so that the result is a list of EventSource objects whose `countries` relationship holds the per-country aggregated event counts, accessible as in the last line of the following code block:

# Initialize our database:
country_en = Country(code="en", name="England")
country_de = Country(code="de", name="Germany")
eventsource_tv = EventSource(name="TV")
eventsource_internet = EventSource(name="Internet")
session.add_all([country_en, country_de, eventsource_tv, eventsource_internet])
# Flush so the Python-side ``uuid.uuid4`` defaults are applied and the ``id``
# attributes are populated before they are used as foreign-key values below.
session.flush()
session.add_all(
    [
        Event(country_id=country_en.id, eventsource_id=eventsource_tv.id),
        Event(country_id=country_en.id, eventsource_id=eventsource_tv.id),
        Event(country_id=country_en.id, eventsource_id=eventsource_tv.id),
        Event(country_id=country_de.id, eventsource_id=eventsource_tv.id),
        Event(country_id=country_en.id, eventsource_id=eventsource_internet.id),
        Event(country_id=country_en.id, eventsource_id=eventsource_internet.id),
    ]
)
session.flush()
# Aggregate eventsources somehow (this is the line where some magic that
# solves the problem should happen). ``.scalars()`` yields EventSource
# objects; a bare ``.all()`` on the ``session.execute()`` result would return
# Row tuples instead, which have no ``countries`` attribute.
eventsources = (
    session.execute(select(EventSource).order_by(EventSource.name)).scalars().all()
)
# Print results. This line should output "2" (eventsource "Internet" for country "en"):
print(eventsources[0].countries[0].events_count)
Sid
  • 263
  • 2
  • 9

1 Answer

0

For those who encounter the same problem, this is what I ended up doing. The SQLAlchemy documentation has an example of a relationship whose target is a SELECT query; based on that example, my solution was to build the aggregate query and then map its result rows to a custom class. This is roughly what I did (not exactly the code I run, but something very similar):


class Country(Base):
    """A country row, identified by a unique short ``code`` (e.g. ``"en"``)."""

    __tablename__ = "countries"

    # No ``index=True`` on the primary key: the PK constraint is already backed
    # by an index, so an extra ``ix_countries_id`` index would be redundant.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # ``index=True`` combined with ``unique=True`` yields a single UNIQUE index.
    code = Column(String, index=True, nullable=False, unique=True)
    name = Column(String, nullable=False)


class EventSource(Base):
    """A source of events (e.g. ``"TV"``), identified by a unique ``name``."""

    __tablename__ = "eventsources"

    # No ``index=True`` on the primary key: the PK constraint is already backed
    # by an index, so an extra ``ix_eventsources_id`` index would be redundant.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String, nullable=False, unique=True)


class Event(Base):
    """An event record linked to exactly one ``Country`` and one ``EventSource``."""

    __tablename__ = "events"

    # No ``index=True`` on the primary key: the PK constraint is already backed
    # by an index, so an extra ``ix_events_id`` index would be redundant.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # ON DELETE CASCADE: deleting a country removes its events in the database.
    country_id = Column(
        UUID(as_uuid=True),
        ForeignKey("countries.id", ondelete="CASCADE"),
        nullable=False,
    )
    # ON DELETE CASCADE: deleting an event source removes its events as well.
    eventsource_id = Column(
        UUID(as_uuid=True),
        ForeignKey("eventsources.id", ondelete="CASCADE"),
        nullable=False,
    )
    # ``func.now()`` is a SQL expression, rendered into the INSERT statement,
    # so the timestamp comes from the database clock.
    created_at = Column(DateTime(timezone=True), default=func.now())
    # Many-to-one link to the event's country; used as the ON clause when the
    # aggregate query joins ``Event`` to ``Country``.
    country = relationship("Country")


@dataclass
class CountriesCount:
    """One aggregate row: how many events an event source has for a country.

    Instances are loaded from the ``events_counts_table`` subquery via an
    imperative mapping; they are read-only views of the aggregate.
    """

    # ``uuid.UUID`` (the Python value type produced by ``UUID(as_uuid=True)``),
    # not the SQLAlchemy ``UUID`` column type that the bare name refers to here.
    eventsource_id: uuid.UUID
    country_code: str
    events_count: int


# Per-(event source, country) event counts as a subquery, equivalent to:
#   SELECT events.eventsource_id, countries.code, COUNT(events.id)
#   FROM events JOIN countries ON events.country_id = countries.id
#   GROUP BY events.eventsource_id, countries.code
events_counts_table = (
    select(
        Event.eventsource_id.label("eventsource_id"),
        Country.code.label("country_code"),
        # Count event rows, as in the raw SQL. With the inner join this equals
        # counting country codes, but it states the intent directly.
        func.count(Event.id).label("events_count"),
    )
    .select_from(Event)
    .join(Country, Event.country)
    .group_by(Event.eventsource_id, Country.code)
    # ``subquery()`` is the 2.0-style spelling; ``Select.alias()`` is a synonym.
    .subquery()
)


# List-valued, read-only relationship from each EventSource to its per-country
# aggregate rows. ``CountriesCount`` is mapped imperatively onto the
# ``events_counts_table`` subquery; (eventsource_id, country_code) serves as its
# primary key because the subquery has no key of its own.
EventSource.countries = relationship(
    registry().map_imperatively(
        CountriesCount,
        events_counts_table,
        primary_key=[
            events_counts_table.c.eventsource_id,
            events_counts_table.c.country_code,
        ],
    ),
    # Aggregate rows cannot be written back, so never persist through this.
    viewonly=True,
    primaryjoin=EventSource.id == events_counts_table.c.eventsource_id,
    # Stable ordering so positional access such as ``countries[0]`` is
    # deterministic; without it the database may return rows in any order.
    order_by=events_counts_table.c.country_code,
)


Sid
  • 263
  • 2
  • 9