
Consider two models:

  • Project:
    • id: An AutoField primary key
    • budget: A PositiveIntegerField
    • some unrelated fields
  • Expense:
    • id: An AutoField primary key
    • amount: A PositiveIntegerField
    • project: A ForeignKey to Project
    • some unrelated fields

For each project, I want to ensure that the sum of its expenses is less than or equal to its budget at all times.

In SQL terms, the result of SELECT SUM(amount) FROM expense WHERE project_id = ?; should always be less than or equal to the result of SELECT budget FROM project WHERE id = ?;
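The invariant can be checked directly with those two queries. The sketch below uses Python's built-in sqlite3 module as a stand-in for PostgreSQL (an assumption made only so it runs anywhere); the table layout follows the two models above, and within_budget is a hypothetical helper name. Checking the invariant like this only detects a violation after the fact; it does nothing to prevent one.

```python
import sqlite3

# In-memory stand-in for the project and expense tables (sqlite3, not PostgreSQL).
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE project (
        id INTEGER PRIMARY KEY,
        budget INTEGER NOT NULL
    );
    CREATE TABLE expense (
        id INTEGER PRIMARY KEY,
        amount INTEGER NOT NULL,
        project_id INTEGER NOT NULL REFERENCES project (id)
    );
""")


def within_budget(conn, project_id):
    """Return True if SUM(amount) <= budget for the given project."""
    # COALESCE turns the NULL that SUM() yields for zero rows into 0.
    (spent,) = conn.execute(
        "SELECT COALESCE(SUM(amount), 0) FROM expense WHERE project_id = ?",
        (project_id,),
    ).fetchone()
    (budget,) = conn.execute(
        "SELECT budget FROM project WHERE id = ?", (project_id,)
    ).fetchone()
    return spent <= budget


conn.execute("INSERT INTO project (id, budget) VALUES (1, 100)")
conn.executemany(
    "INSERT INTO expense (amount, project_id) VALUES (?, ?)",
    [(60, 1), (30, 1)],
)
print(within_budget(conn, 1))  # True: 90 <= 100
conn.execute("INSERT INTO expense (amount, project_id) VALUES (?, ?)", (20, 1))
print(within_budget(conn, 1))  # False: 110 > 100
```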

Is there any way to enforce this in Django, keeping in mind that multiple people may be accessing the web server and creating Expenses at the same time?
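To see why concurrent requests are the hard part, here is a deterministic pure-Python replay (no Django or database involved) of two requests that each read the total, compare it against the budget, and then insert. The helper names check and insert and the amounts are made up for illustration.

```python
# Deterministic replay of two concurrent requests that both follow the naive
# "read the total, compare, then insert" pattern without any locking.
budget = 100
expenses = [60]  # amounts already recorded for one project


def check(amount):
    """Step 1: read the current total and decide whether the new expense fits."""
    return sum(expenses) + amount <= budget


def insert(amount):
    """Step 2: record the expense."""
    expenses.append(amount)


# Interleaving: both requests run their check before either one inserts.
ok_a = check(30)  # sees 60, so 60 + 30 = 90 <= 100 -> allowed
ok_b = check(30)  # still sees 60, so it is also allowed
if ok_a:
    insert(30)
if ok_b:
    insert(30)

print(sum(expenses))  # 120, which exceeds the budget of 100
```

Each check was correct at the moment it ran; the invariant broke only because both reads happened before either write. Any fix has to make the read, check and write for a given project happen one request at a time.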

I am using PostgreSQL as my database backend.

I have tried using select_for_update, but that doesn't prevent INSERTs, and it doesn't seem to work on aggregations anyway.

I also considered saving the Expense to the database first, then querying the total cost and deleting the Expense if it is over budget. However, that code would have to run outside a transaction so that other threads could see the new row, and if the server stopped mid-processing, the data could be left in an invalid state.

Matthew
  • What did you try with `select_for_update`? Try selecting the `project` instead of the expenses. – Alasdair May 15 '17 at 12:36
  • @Alasdair That... was a lot simpler than I was expecting. I can't believe I didn't think of that. Anyway, could you please add it as an answer so I can accept it? – Matthew May 15 '17 at 13:10
  • Glad it helped. If you add an answer with some example code, that would probably be more useful to other users in future than if I just copy my comment. – Alasdair May 15 '17 at 13:16

1 Answer


Thanks to @Alasdair for pointing me in the right direction.

After filling out the fields of inst (a new Expense), do:

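from django import forms
from django.db import transaction
from django.utils.translation import gettext as _

from . import models  # assumes this code lives in the app that defines Project

with transaction.atomic():
    # Lock this project's row. Any other transaction that runs the same
    # select_for_update() on this project blocks here until we commit.
    project = models.Project.objects.select_for_update().get(
        pk=project_id)
    cost = project.total_cost()
    budget = project.budget

    if cost + inst.cost > budget:
        # Raising inside atomic() rolls back the transaction and releases the lock.
        raise forms.ValidationError(_('Over-budget'))

    inst.save()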
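Why locking the Project row is enough: every request that wants to add an expense must first take the same row lock, so the read, check and save for a given project run one request at a time. The sketch below is only an in-memory analogy using threading.Lock, not Django or PostgreSQL; budgets, add_expense and the other names are hypothetical.

```python
import threading
from collections import defaultdict

# In-memory analogy only: one lock per project stands in for the row lock that
# select_for_update() takes on the Project row inside transaction.atomic().
budgets = {1: 100}                   # project_id -> budget
expenses = defaultdict(list)         # project_id -> list of saved amounts
project_locks = defaultdict(threading.Lock)
registry_lock = threading.Lock()     # guards creation of per-project locks


def lock_for(project_id):
    """Return the single lock shared by every request for this project."""
    with registry_lock:
        return project_locks[project_id]


def add_expense(project_id, amount):
    """Read the total, check the budget and save, all while holding the lock."""
    with lock_for(project_id):
        total = sum(expenses[project_id])
        if total + amount > budgets[project_id]:
            return False             # the real code raises ValidationError
        expenses[project_id].append(amount)
        return True


results = []
results_lock = threading.Lock()


def worker():
    saved = add_expense(1, 30)
    with results_lock:
        results.append(saved)


threads = [threading.Thread(target=worker) for _ in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sum(expenses[1]), results.count(True))  # 90 3
```

Ten threads each try to add 30 against a budget of 100, so exactly three succeed regardless of scheduling. The same caveat applies to the real code: the guarantee holds only if every code path that creates or edits expenses (or lowers a budget) locks the Project row first.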

Note that I have total_cost defined as a method on Project:

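from django.db import models

class Project(models.Model):
    # budget and the other fields are omitted here
    def total_cost(self):
        # Sum() gives None when there are no expenses, so fall back to 0.
        return self.expense_set.aggregate(
            t=models.Sum('cost'))['t'] or 0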
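One pitfall with total_cost(): for a project with no expenses yet, SQL's SUM over zero rows is NULL, and Django's aggregate() hands that back as None, so cost + inst.cost would raise a TypeError. The sqlite3 sketch below (a stand-in for PostgreSQL, which also returns NULL here) shows the NULL and the COALESCE fix. In Django itself, Sum('cost', default=0) (Django 4.0+) or wrapping the aggregate in django.db.models.functions.Coalesce has the same effect.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE expense (amount INTEGER, project_id INTEGER)")

# SUM over zero matching rows is NULL in SQL, which Python receives as None.
(raw,) = conn.execute(
    "SELECT SUM(amount) FROM expense WHERE project_id = ?", (1,)
).fetchone()

# COALESCE replaces that NULL with 0.
(safe,) = conn.execute(
    "SELECT COALESCE(SUM(amount), 0) FROM expense WHERE project_id = ?", (1,)
).fetchone()

print(raw, safe)  # None 0
```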

Matthew