29

I use Python's unittest module and want to check if two complex data structures are equal. The objects can be lists of dicts with all sorts of values: numbers, strings, Python containers (lists/tuples/dicts) and numpy arrays. The latter are the reason for asking the question, because I cannot just do

self.assertEqual(big_struct1, big_struct2)

because it produces a

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

I imagine that I need to write my own equality test for this. It should work for arbitrary structures. My current idea is a recursive function that:

  • tries direct comparison of the current "node" of arg1 to the corresponding node of arg2;
  • if no exception is raised, moves on ("terminal" nodes/leaves are processed here, too);
  • if ValueError is caught, goes deeper until it finds a numpy.array;
  • compares the arrays (e.g. like this).

What seems a little problematic is keeping track of "corresponding" nodes of two structures, but perhaps zip is all I need here.
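For illustration, here is a minimal sketch of such a recursive check. The name deep_equal is hypothetical. Unlike the plan above, it inspects types up front instead of catching ValueError, and it assumes the containers are plain lists, tuples and dicts:

```python
import numpy as np

def deep_equal(a, b):
    """Recursively compare nested lists/tuples/dicts that may hold numpy arrays."""
    # Handle arrays first so that == is never applied to them directly.
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        return (isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
                and a.shape == b.shape and bool(np.array_equal(a, b)))
    # Dicts: same keys, then compare the values key by key.
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(deep_equal(a[k], b[k]) for k in a)
    # Sequences: same type and length, then zip pairs up corresponding nodes.
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return (type(a) is type(b) and len(a) == len(b)
                and all(deep_equal(x, y) for x, y in zip(a, b)))
    # Leaves (numbers, strings, ...) fall back to ordinary equality.
    return a == b
```

Checking the lengths before zipping matters, because zip silently stops at the shorter sequence.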

The question is: are there good (simpler) alternatives to this approach? Maybe numpy presents some tools for this? If no alternatives are suggested, I will implement this idea (unless I have a better one) and post as an answer.

P.S. I have a vague feeling that I might have seen a question addressing this problem, but I can't find it now.

P.P.S. An alternative approach would be a function that traverses the structure and converts all numpy.arrays to lists, but is this any easier to implement? Seems the same to me.
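A rough sketch of that conversion, for comparison. The helper name arrays_to_lists is made up, and it assumes only plain lists, tuples and dicts occur as containers (a namedtuple, for example, would not survive the type(obj)(...) call):

```python
import numpy as np

def arrays_to_lists(obj):
    """Return a copy of obj with every numpy array replaced by a nested list."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: arrays_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the same container type from the converted items.
        return type(obj)(arrays_to_lists(item) for item in obj)
    return obj
```

After converting both structures, self.assertEqual(arrays_to_lists(a), arrays_to_lists(b)) should work, at the price of no longer distinguishing an array from a list with the same contents.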


Edit: Subclassing numpy.ndarray sounds very promising, but obviously I don't have both sides of the comparison hard-coded into a test. One of them, though, is indeed hardcoded, so I can:

  • populate it with custom subclasses of numpy.array;
  • change isinstance(other, SaneEqualityArray) to isinstance(other, np.ndarray) in jterrace's answer;
  • always use it as LHS in comparisons.

My questions in this regard are:

  1. Will it work (I mean, it sounds all right to me, but maybe some tricky edge cases will not be handled correctly)? Will my custom object always end up as LHS in the recursive equality checks, as I expect?
  2. Again, are there better ways (given that I get at least one of the structures with real numpy arrays)?

Edit 2: I tried it out, the (seemingly) working implementation is shown in this answer.

Lev Levitsky
  • I imagine that writing an equality test that works on arbitrary data structures would be arbitrarily hard. Is there really no fixed structure to those? – loopbackbee Jan 09 '13 at 22:27
  • @goncalopp There are several of them, quite convoluted, and in theory subject to changes. I don't want to depend on it, especially since I don't know of a way to compare _everything except `X`_ in two structures even if I know where `X` is. – Lev Levitsky Jan 09 '13 at 22:34
  • Then, personally, I'd go with the recursive function approach. I'd explicitly check the `type` of the object first, though - doing a blind comparison as the first step may be sound, but would be wasteful if your data structures are big, since the values will need to be rechecked if a `ValueError` is raised. – loopbackbee Jan 09 '13 at 22:50
  • @goncalopp Thanks for the input. Performance is not a key issue, this is for testing purposes only. I am more concerned about minimizing the effort needed to implement and maintain the solution. – Lev Levitsky Jan 09 '13 at 23:04

7 Answers

14

Would have commented, but it gets too long...

Fun fact: you cannot use == to test whether arrays are the same. I would suggest you use np.testing.assert_array_equal instead:

  1. It checks the shape and the values (by default it does not insist on identical dtypes).
  2. It doesn't fail because of the neat little fact that (float('nan') == float('nan')) == False. (A normal Python sequence == has an even more fun way of ignoring this sometimes, because it uses PyObject_RichCompareBool, which does an `is` quick check first. That is incorrect for NaNs, but for testing it is of course perfect.)
  3. There is also assert_allclose, because floating point equality gets very tricky once you do actual calculations. You usually want almost the same values, since the results can be hardware dependent or possibly random, depending on what you do with them.

I would almost suggest serializing it with pickle if you want something this insanely nested, but that is overly strict (and point 3 is of course fully broken then). For example, the memory layout of your array does not matter for equality, but it does matter to its serialization.
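A small illustration of the numpy.testing helpers mentioned above, with made-up sample arrays:

```python
import numpy as np
from numpy import testing as npt

a = np.array([1.0, np.nan, 3.0])
b = np.array([1.0, np.nan, 3.0])

# Elementwise == is False wherever both sides are NaN.
elementwise = (a == b)

# assert_array_equal treats NaNs in matching positions as equal, so this passes.
npt.assert_array_equal(a, b)

# assert_allclose tolerates small floating point differences, so this passes too.
npt.assert_allclose(np.array([0.1 + 0.2]), np.array([0.3]), rtol=1e-7)
```

Both helpers raise AssertionError on a mismatch, so inside a unittest test method a failure is reported like any other failed assertion.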

user1248490
seberg
  • Serializing with `pickle` [has its own problems](http://bugs.python.org/issue6784)... It's good to know about `numpy.testing` utilities, but I'm still not sure how to apply them here. – Lev Levitsky Jan 10 '13 at 07:05
10

The assertEqual function will invoke the __eq__ method of objects, which should recurse for complex data types. The exception is numpy, which doesn't have a sane __eq__ method. Using a numpy subclass from this question, you can restore sanity to the equality behavior:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

This test passes.

jterrace
  • Instead of messing with `__eq__` I would *much* rather mess with `__nonzero__`. though there are good reasons why numpy does it as it does it, this just invites errors. – seberg Jan 10 '13 at 02:16
  • Thanks so much for the link (now it seems like I'd seen it before)! But I'm testing a function that returns stuff with `numpy.array`, not its custom subclasses. How would I do the conversion before assertion? (or otherwise I need to remove `isinstance(other, SaneEqualityArray)` from `__eq__`, is that a good idea?) – Lev Levitsky Jan 10 '13 at 07:10
  • Also @seberg can you please elaborate on your concern about overriding `__eq__`? (and maybe spell out the suggestion about `__nonzero__` for me: I take it needs to be called on `A-B`? So `A-B` should give an instance of the subclass? Doesn't it mean I also need to override `__sub__`?) – Lev Levitsky Jan 10 '13 at 07:17
  • @LevLevitsky what causes the errors is not `__eq__`, what causes the errors and is ill defined is `__nonzero__` (ie. `bool(np.ndarray)`), changing `__nonzero__` should probably not change a working program unless it relies on errors being thrown. Looks like a big advantage to me... – seberg Jan 10 '13 at 09:36
  • @seberg Thanks, but I wasn't planning on changing anything in the functions I test. I still want them to put `numpy.array` in the return values, not custom subclasses of it. This is in reply to your concern about "changing a working program". – Lev Levitsky Jan 10 '13 at 10:10
  • By the way, the argument you pass to the constructor of `SaneEqualityArray` is actually interpreted as its shape, and the contents are not initialized. I made this mistake, too: `ndarray` can't be created with `array`'s signature. – Lev Levitsky Jan 10 '13 at 21:49
  • hmm, yeah maybe not a great solution then – jterrace Jan 10 '13 at 22:00
  • Actually, it seems to work (except that I do need to tweak it to use `almost_equal` somehow). The problem is in the demo test, not the class definition. I'll let you know when I get it to work. – Lev Levitsky Jan 10 '13 at 22:37
  • I added the modified code as a separate answer, it also shows the correct (or at least working) way to create instances. – Lev Levitsky Jan 11 '13 at 11:10
7

So the idea illustrated by jterrace seems to work for me with a slight modification:

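import numpy as np

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # Compare plain-ndarray views so that comparisons inside NumPy
        # cannot re-enter this overridden __eq__.
        return (isinstance(other, np.ndarray) and self.shape == other.shape and
                np.allclose(np.asarray(self), np.asarray(other)))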

Like I said, the container with these objects should be on the left side of the equality check. I create SaneEqualityArray objects from existing numpy.ndarrays like this:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

in accordance with ndarray constructor signature:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

This class is defined within the test suite and serves for testing purposes only. The RHS of the equality check is an actual object returned by the tested function and contains real numpy.ndarray objects.
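As a possible alternative to calling the ndarray constructor, the expected values can be created by view casting. The sketch below is self-contained, so it repeats the class; the helper name as_sane and the sample data are invented for the example:

```python
import numpy as np

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # Plain-ndarray views keep NumPy from re-entering this __eq__.
        return (isinstance(other, np.ndarray) and self.shape == other.shape
                and bool(np.allclose(np.asarray(self), np.asarray(other))))

def as_sane(array_like):
    """View existing data as a SaneEqualityArray without copying an ndarray."""
    return np.asarray(array_like).view(SaneEqualityArray)

# Hard-coded expected structure (LHS) versus a result holding real ndarrays (RHS).
expected = [{'foo': as_sane([3.0, 4.0]), 'd': {'doo': 3}}, as_sane([5, 6]), 34]
actual = [{'foo': np.array([3.0, 4.0]), 'd': {'doo': 3}}, np.array([5, 6]), 34]
result = (expected == actual)
```

Since SaneEqualityArray is a subclass of ndarray, Python's comparison rules give its __eq__ priority even when the plain array happens to be the left operand, so the order of the two structures should matter less than one might fear.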

P.S. Thanks to the authors of both answers posted so far, they were both very helpful. If anyone sees any problems with this approach, I'd appreciate your feedback.

Lev Levitsky
2

I would define my own assertNumpyArraysEqual() method that explicitly makes the comparison that you want to use. That way, your production code is unchanged but you can still make reasonable assertions in your unit tests. Make sure to define it in a module that includes __unittest = True so that it will not be included in stack traces:

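import numpy
__unittest = True  # hides this module's frames from unittest tracebacks

def assertNumpyArraysEqual(self, other):
    # Compare shapes first: allclose alone would allow broadcasting.
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other):
        raise AssertionError("Elements don't match!")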
dbn
  • Thanks, it's a good idea and would be the best option in case if both arrays are generated by function(s) being tested. – Lev Levitsky Mar 14 '13 at 09:42
1

Check numpy.testing.assert_almost_equal, which "raises an AssertionError if two items are not equal up to desired precision", e.g.:

import numpy as np
import numpy.testing as npt

# Raises AssertionError: the values differ by about 6.7e-9, more than decimal=9 allows.
npt.assert_almost_equal(np.array([1.0, 2.3333333333333]),
                        np.array([1.0, 2.33333334]), decimal=9)
Hanan Shteingart
1

I've run into the same issue and developed a function that compares for equality by computing a fixed hash of the object. This has the added advantage that you can test that an object is as expected by comparing its hash against a fixed hash stored in code.

The code (a stand-alone Python file) is here. There are two functions: fixed_hash_eq, which solves your problem, and compute_fixed_hash, which makes a hash from the structure. Tests are here.

Here's a test:

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)
Peter
0

Building on @dbw (with thanks), the following method inserted within the test-case subclass worked well for me:

def assertNumpyArraysEqual(self, this, that, msg=''):
    '''
    modified from http://stackoverflow.com/a/15399475/5459638
    '''
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this, that):
        raise AssertionError("Elements don't match!")

I call it as self.assertNumpyArraysEqual(this, that) inside my test case methods, and it works like a charm.

XavierStuvw