2

I tried to patch a resource with a different return value for each thread, but sometimes the patched function becomes unpatched before the other thread can use it.

file1.py:

import subprocess

def dont_run_me() -> str:
    """Run a deliberately destructive command: the side effect tests must mock.

    WARNING: this really invokes ``rm -rf /``. It exists only as an example of
    work that a test must never trigger, so always patch ``file1.dont_run_me``
    before exercising any caller.

    Returns:
        ``'Deleted everything'`` when the command exits with status 0.

    Raises:
        FileNotFoundError: if the command exits with a non-zero status; the
            message includes the command's decoded stderr.
    """
    result = subprocess.run('rm -rf /'.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        # Python's built-in is FileNotFoundError. ``FileNotFoundException`` is
        # not defined in Python and would turn this branch into a NameError.
        raise FileNotFoundError(
            f'Something went wrong while deleting everything, I wonder why: {result.stderr.decode()}')
    return 'Deleted everything'

def function_being_tested(identifier: str) -> str:
    """Delegate to ``dont_run_me`` and hand back whatever it produces.

    ``identifier`` is accepted for callers' convenience but is not used here.
    """
    outcome = dont_run_me()
    return outcome

test_file1.py:

import unittest

from unittest.mock import patch
from concurrent.futures import ThreadPoolExecutor

from file1 import function_being_tested

class TestFunction(unittest.TestCase):
    def assert_function_not_called(self, thread_name):
        """Patch ``file1.dont_run_me`` in this thread and check the mocked value.

        NOTE: ``patch`` swaps a module-level attribute that every thread shares.
        When two threads enter and exit their own ``patch`` concurrently, one
        thread's exit can restore the real ``dont_run_me`` while the other thread
        is still relying on its mock. This per-thread patching is therefore racy.
        """
        # Patch the name that function_being_tested looks up at call time.
        # Patching file1.function_being_tested would not affect the reference
        # already imported into this test module, so the real code would run.
        with patch('file1.dont_run_me') as dont_run_me_patch:
            dont_run_me_patch.return_value = thread_name
            result = function_being_tested(thread_name)
            print(result)
            self.assertEqual(result, thread_name)

    def test_function_being_tested(self):
        """Run the per-thread patching check from two worker threads."""
        # The context manager shuts the pool down once both tasks are done.
        with ThreadPoolExecutor() as executor:
            thread_1 = executor.submit(self.assert_function_not_called, 'thread_1')
            thread_2 = executor.submit(self.assert_function_not_called, 'thread_2')

            # result() re-raises any exception (e.g. a failed assertion)
            # raised inside the worker thread.
            thread_1.result()
            thread_2.result()

Result when successful:

thread_1
thread_2

Result when failed:

thread_1
Something went wrong while deleting everything, I wonder why: ...

Elsewhere I saw that patch is not thread-safe: https://stackoverflow.com/a/26877522/5263074. How can I make sure that each thread can patch a function that must be mocked with its own return value?

yitzchak24
  • 93
  • 1
  • 1
  • 11

1 Answer

0

Since patch modifies the target globally (the patched attribute is shared by every thread), patching from within individual threads is not thread-safe.

Instead of patching inside each thread, patch once before the threads are created, and keep the patch active until all of them have finished. Use side_effect so that this single mock can return a different value for each thread.

For example:

import unittest

from unittest.mock import patch
from concurrent.futures import ThreadPoolExecutor

from file1 import function_being_tested

class TestFunction(unittest.TestCase):
    def assert_function_not_called(self, thread_name):
        """Call the function under test and check it returned this thread's value.

        Relies on ``test_function_being_tested`` having created
        ``self._thread_state`` and installed the shared mock.
        """
        # Record this worker's expected value in thread-local storage so the
        # single shared mock can answer differently for each thread.
        self._thread_state.name = thread_name
        result = function_being_tested(thread_name)
        print(result)
        self.assertEqual(result, thread_name)

    def test_function_being_tested(self):
        """Patch once, then exercise the function from two worker threads."""
        import threading

        # Each thread sees only the attributes it set on this object.
        self._thread_state = threading.local()

        def _mocked_dont_run_me():
            # function_being_tested calls dont_run_me() with no arguments, so
            # the per-thread value must come from thread-local state rather
            # than from a parameter.
            return self._thread_state.name

        # Patch a single time, before any worker starts, and keep the patch
        # active until every worker has finished. patch itself is not
        # thread-safe, so it must not be entered or exited inside the threads.
        with patch('file1.dont_run_me', side_effect=_mocked_dont_run_me):
            with ThreadPoolExecutor() as executor:
                thread_1 = executor.submit(self.assert_function_not_called, 'thread_1')
                thread_2 = executor.submit(self.assert_function_not_called, 'thread_2')

                # result() re-raises any failure from the worker thread.
                thread_1.result()
                thread_2.result()

yitzchak24
  • 93
  • 1
  • 1
  • 11