
I am developing a Django backend for an online course platform. When a student submits code for grading, the backend runs unit tests (using the unittest library) and returns the message of the failed test to the frontend.

After much help from ChatGPT, I can catch the message passed to the assertions by using the following wrapper:

import unittest

def save_custom_message(func):
    def wrapper(self, *args, **kwargs):
        MyTests.msg = kwargs.get('msg', '')
        return func(self, *args, **kwargs)
    return wrapper

class CustomTestCase(unittest.TestCase):
    def __getattribute__(self, name):
        attr = super().__getattribute__(name)
        if name.startswith("assert") and callable(attr):
            return save_custom_message(attr)
        return attr

Here is the usage:

class MyTests(CustomTestCase):
    def test_division(self):
        self.assertEqual(10 / 2, 5, msg="First division assert")
        self.assertEqual(10 / 2, 3, msg="Second division assert")
        self.assertEqual(10 / 5, 2, msg="Third division assert")

def run_unit_tests(test_class):
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(test_class)

    for test in suite:
        result = unittest.TextTestRunner().run(test)
        if result.failures:
            return MyTests.msg
    return True

result = run_unit_tests(MyTests)
if result is True:
    print("All tests passed!")
else:
    print(f"Test failed: {result}")

This works just fine. But now I would like to specify a single message for the whole test method, as the default value of a msg parameter:

class MyTests(CustomTestCase):
    def test_division(self, msg="Something is wrong with division"):
        self.assertEqual(10 / 2, 5)
        self.assertEqual(10 / 2, 3)
        self.assertEqual(10 / 5, 2)

My understanding was that test_division is a method of MyTests just like assertEqual. Therefore, I tried also wrapping attributes whose names start with test:

class CustomTestCase(unittest.TestCase):
    default_message = "The test case failed"
    def __getattribute__(self, name):
        attr = super().__getattribute__(name)
        if (name.startswith("assert") or name.startswith("test")) and \
           callable(attr):
            return save_custom_message(attr)
        return attr

However, I got the error:

TypeError: save_custom_message.<locals>.wrapper() missing 1 required positional argument: 'self'

This suggests that my understanding is wrong. I would very much appreciate an in-depth explanation of the issue and a suggestion on how to fix it.

P.S. Another point of confusion I would be happy to resolve: when I printed args inside the wrapper for `self.assertEqual(10 / 2, 3, msg="Second division assert")`, only 3 was printed, not 5. It looks as if the first argument is not passed at all. I would be happy to understand what's going on.
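A minimal sketch reproducing both symptoms. `capture_args` is a hypothetical stand-in for the wrapper above: it has the same signature, with an explicit `self`, but returns what it received instead of calling `func`. It illustrates that attribute lookup on an instance (`case.assertEqual`, `case.test_division`) already returns a bound method, so the wrapper's own `self` is an extra parameter that grabs the first real argument, or is left unfilled when there is none:

```python
import unittest


def capture_args(func):
    # Hypothetical stand-in for the question's wrapper: same signature
    # (an explicit `self`), but it returns what it received instead of calling func.
    def wrapper(self, *args, **kwargs):
        return self, args, kwargs
    return wrapper


class MyTests(unittest.TestCase):
    def test_division(self, msg="Something is wrong with division"):
        pass


case = MyTests("test_division")
bound = case.assertEqual
print(bound.__self__ is case)  # True: the instance is already bound

# The wrapper's `self` swallows 10 / 2, so args only holds (3,)
print(capture_args(bound)(10 / 2, 3, msg="Second division assert"))
# (5.0, (3,), {'msg': 'Second division assert'})

# unittest calls test methods with no arguments, so `self` receives nothing
raised = False
try:
    capture_args(case.test_division)()
except TypeError:
    raised = True
print(raised)  # True
```

This also explains why the original wrapper still worked for assertions: it passed the captured value back on with `func(self, *args, **kwargs)`, so `assertEqual` received all of its arguments anyway.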

AlwaysLearning

1 Answer


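I have researched the topic. There were two problems with my approach. First, __getattribute__ already returns a bound method (self is already attached to it), so the wrapper's own self parameter captured the first real positional argument. For assertEqual(10 / 2, 3, ...) it captured 5.0, which is why only 3 appeared in args; for test_division, which unittest calls with no arguments, there was nothing to bind to it, hence the TypeError. Second, as explained in this reply, default keyword arguments are not passed into the wrapper, so kwargs.get('msg') never sees the default value; it has to be read from the function's signature instead.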
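A small sketch of the second point, using a hypothetical `show_kwargs` wrapper and `Example` class: a default value never shows up in `kwargs`, but `inspect.signature` can still recover it from a bound method (whose signature omits `self`):

```python
from inspect import signature


def show_kwargs(func):
    # Returns the kwargs the wrapper actually received, plus func's result.
    def wrapper(*args, **kwargs):
        return kwargs, func(*args, **kwargs)
    return wrapper


class Example:
    def test_division(self, msg="my message"):
        return msg


bound = Example().test_division

# Called the way unittest calls it: no arguments at all.
print(show_kwargs(bound)())
# ({}, 'my message')

# The default only exists in the function's signature.
print(signature(bound).parameters["msg"].default)
# my message
```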

This works:

import unittest
from functools import wraps
from inspect import signature

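def save_custom_message(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # func is already a bound method, so the wrapper takes no `self`.
        # unittest calls test methods with no arguments, so the default
        # value of `msg` must be read from the signature, not from kwargs.
        param = signature(func).parameters.get('msg')
        MyTests.msg = param.default if param is not None else "The test case failed"
        return func(*args, **kwargs)
    return wrapper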

class CustomTestCase(unittest.TestCase):
    def __getattribute__(self, name):
        attr = object.__getattribute__(self, name)
        if name.startswith("test") and callable(attr):
            return save_custom_message(attr)
        return attr

class MyTests(CustomTestCase):
    def test_division(self, msg = 'my message'):
        self.assertEqual(10 / 2, 5)
        self.assertEqual(10 / 2, 3)
        self.assertEqual(10 / 5, 2)

def run_unit_tests(test_class):
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(test_class)

    for test in suite:
        result = unittest.TextTestRunner().run(test)
        # wasSuccessful() is False both for failed assertions (result.failures)
        # and for unexpected exceptions (result.errors)
        if not result.wasSuccessful():
            return MyTests.msg
    return True

result = run_unit_tests(MyTests)
if result is True:
    print("All tests passed!")
else:
    print(f"Test failed: {result}")