
I have the following statement in one of the methods under unit test.

db_employees = self.db._session.query(Employee).filter(
    Employee.dept == new_employee.dept).all()

I want db_employees to receive a mock list of employees. I tried to achieve this using:

 m = MagicMock()
 m.return_value.filter().all().return_value = employees

where employees is a list of employee objects. But this did not work: when I print any attribute of the result, I get a mock value instead of the real one. This is how the code looks:

class Database(object):
    def __init__(self, user=None, passwd=None, db="sqlite:////tmp/emp.db"):
        try:
            engine = create_engine(db)
        except Exception:
            raise ValueError("Database '%s' does not exist." % db)

        def on_connect(conn, record):
            conn.execute('pragma foreign_keys=ON')

        if 'sqlite://' in db:
            event.listen(engine, 'connect', on_connect)
        Base.metadata.bind = engine
        DBSession = sessionmaker(bind=engine)
        self._session = DBSession()


class TestEmployee(MyEmployee):
    def setUp(self):
        self.db = emp.database.Database(db=options.connection)
        self.db._session._autoflush()

    @mock.patch.object(session.Session, 'add')     
    @mock.patch.object(session.Session, 'query')  
    def test_update(self, mock_query, mock_add):
        employees = [{'id': 1,
                      'name': 'Pradeep',
                      'department': 'IT',
                      'manager': 'John'}]
        mock_add.side_effect = self.add_side_effect
        mock_query.return_value = self.query_results()  
        self.update_employees(employees)

    def add_side_effect(self, instance, _warn=True):
        # Code to mock add.
        # Values will be stored in a dict, which will be used to
        # check against the expected values.
        pass

    def query_results(self):
        m = MagicMock()
        if self.count == 0:
            m.return_value.filter.return_value.all.return_value = [employee]
        else:
            m.return_value.filter.return_value.all.return_value = [department]
        return m

I use query_results because the method under test calls query twice: first on the employee table and then on the department table.

How do I mock this chained function call?

Pradeep
  • First of all, replace `m.return_value.filter...` with `m.filter...`, because `m` is already set as the return value of `query`. Also, please add to your question whether `mock_query.mock_calls` is empty. As far as I can tell, you are patching the wrong object. – Michele d'Amico Dec 16 '15 at 12:53
  • Now try to use `@mock.patch(__name__ + '.DBSession.query')` to patch query. – Michele d'Amico Dec 16 '15 at 13:29

3 Answers

m = Mock()
m.session.query().filter().all.return_value = employees

https://docs.python.org/3/library/unittest.mock.html
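
As a minimal sketch of the one-liner above (the `find` helper and the sample data are hypothetical stand-ins for the code under test), the chain can be configured in two ways. Both return the same list, but calling the mock during setup is itself recorded as a call, whereas going through `return_value` leaves the call history clean:

```python
from unittest.mock import Mock

employees = [{"id": 1, "name": "Pradeep"}]  # hypothetical sample data

# Style 1: configure by calling the mock. Each () here is recorded as a call.
called = Mock()
called.session.query().filter().all.return_value = employees

# Style 2: configure through return_value. Nothing is called during setup.
clean = Mock()
clean.session.query.return_value.filter.return_value.all.return_value = employees


def find(db):
    """Hypothetical stand-in for the code under test."""
    return db.session.query("Employee").filter("dept == 'IT'").all()


result_called = find(called)
result_clean = find(clean)

# Only the second style keeps "called once" assertions meaningful.
clean.session.query.assert_called_once_with("Employee")
```

With the first style, `called.session.query` has been called twice by this point (once during setup, once by `find`), so `assert_called_once()` on it would fail.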

Mark Amery
Tate Thurston
  • Ah, this is brilliant -- we had been creating the results of chained methods manually, e.g. `obj = Mock(); Object = Mock(); Object.return_value = obj`, but this stopped working once there was more than 1 level of chaining, and the `package.one().two().three().return_value = ...` is all that worked – kevlarr Oct 07 '19 at 14:14
  • How do you get around `Expected 'filter' to have been called once. Called 2 times.` when doing something like `m.session.query.assert_called_once()`? – themanatuf Jan 19 '20 at 00:38
  • This makes your test dependent on implementation details and quite brittle... is there really not a better way? – ronathan Jan 22 '21 at 10:07
  • This will add to the recorded calls of the mock. You can get around that by using the `return_value` attributes instead: `m.session.query.return_value.filter.return_value.all.return_value = employees` – Spencer Nov 05 '22 at 00:02

I found a solution to a similar problem where I needed to mock out a nested set of filtering calls.

Given code under test similar to the following:

interesting_cats = (session.query(Cat)
                           .filter(Cat.fur_type == 'furry')
                           .filter(Cat.voice == 'meowrific')
                           .filter(Cat.color == 'orange')
                           .all())

You can set up mocks like the following:

mock_session_response = MagicMock()
# This is the magic - create a mock loop
mock_session_response.filter.return_value = mock_session_response
# We can exit the loop with a call to 'all'
mock_session_response.all.return_value = provided_cats

mock_session = MagicMock(spec=Session)
mock_session.query.return_value = mock_session_response
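
A self-contained sketch of the same loop, using plain `MagicMock` objects instead of `spec=Session` so it runs without SQLAlchemy installed. The `find_cats` helper, the string filter arguments, and the sample data are assumptions made for illustration only:

```python
from unittest.mock import MagicMock

provided_cats = ["Tom", "Felix"]  # hypothetical sample data

# filter() returns the same mock, so any number of chained filters works.
mock_query = MagicMock()
mock_query.filter.return_value = mock_query
# all() leaves the loop and yields the canned result.
mock_query.all.return_value = provided_cats

mock_session = MagicMock()
mock_session.query.return_value = mock_query


def find_cats(session):
    """Hypothetical stand-in for the code under test."""
    return (session.query("Cat")
                   .filter("fur_type == 'furry'")
                   .filter("voice == 'meowrific'")
                   .filter("color == 'orange'")
                   .all())


cats = find_cats(mock_session)
```

Because every `filter` call lands on the same mock, `mock_query.filter.call_args_list` records each filter expression in order, which can be useful for asserting on the filters without caring how many there are.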

You should patch the query() method of the Database's _session attribute and configure it to give you the right answer. You can do it in many ways, but IMHO the cleanest is to patch DBSession's query static reference. I don't know which module you imported DBSession from, so I'll patch the local reference.

The other aspect is the mock configuration: we set query's return value, which in your case becomes the object that has the filter() method.

class TestEmployee(MyEmployee):
    def setUp(self):
        self.db = emp.database.Database(db=options.connection)
        self.db._session._autoflush()
        self.log_add = {}

    @mock.patch(__name__ + '.DBSession.add')
    @mock.patch(__name__ + '.DBSession.query')
    def test_update(self, mock_query, mock_add):
        employees = [{'id': 1,
                      'name': 'Pradeep',
                      'department': 'IT',
                      'manager': 'John'}]
        mock_add.side_effect = self.add_side_effect
        mock_query.return_value = self.query_results()  
        self.update_employees(employees)
        # ... your test here

    def add_side_effect(self, instance, _warn=True):
        # ... storing data
        self.log_add[...] = [...]

    def query_results(self):  
        m = MagicMock()
        value = "[department]"
        if not self.count:  
             value = "[employee]"  
        m.filter.return_value.all.return_value = value 
        return m
Michele d'Amico
  • It worked. Changed mock_query.return_value to mock_query.side_effect = self.query_results. Now I am able to get the objects as expected. I am accepting your answer, as it helped me to resolve the chained function call. – Pradeep Dec 17 '15 at 07:49
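
The `side_effect` switch mentioned in the comment above can be sketched as follows. This is one possible variant, not the asker's exact code: it passes a list to `side_effect` so that each successive `query` call returns the next prepared mock, and the `make_query` helper and row values are hypothetical:

```python
from unittest.mock import MagicMock


def make_query(rows):
    """Hypothetical helper: build a query mock whose filter().all() returns rows."""
    q = MagicMock()
    q.filter.return_value.all.return_value = rows
    return q


mock_session = MagicMock()
# With an iterable side_effect, each call to query() returns the next item.
mock_session.query.side_effect = [
    make_query(["employee"]),    # first call: employee table
    make_query(["department"]),  # second call: department table
]

first = mock_session.query("Employee").filter("dept == 'IT'").all()
second = mock_session.query("Department").filter("name == 'IT'").all()
```

Assigning a function to `side_effect` instead, as in the comment, works the same way: the function is invoked on every `query` call and its return value becomes the result of that call.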