
I wrote a package that uses multiprocessing.Pool inside one of its functions.

Because of this, it is mandatory (as specified in the multiprocessing documentation under "Safe importing of main module") that the outermost calling module can be imported safely, i.e. without starting a new process. This is usually achieved with the if __name__ == "__main__": statement, as explained in that section of the documentation.

My understanding (please correct me if I'm wrong) is that each new process imports the outermost calling module. If that module is not "import-safe", importing it starts another process, which imports the outermost module again, and so on recursively until everything crashes.

If the outermost module is not "import-safe", the main function usually hangs when it is launched, without printing any warning, error, or other message.

Since using if __name__ == "__main__": is usually not mandatory, and users are not always aware of every module used inside a package, I would like to check at the beginning of my function whether the user complied with this requirement and, if not, raise a warning or an error.

Is this possible? How can I do this?

Consider the following example.

Let's say I developed my_module.py and I share it online/in my company.

# my_module.py
from multiprocessing import Pool

def f(x):
    return x*x

def my_function(x_max):
    with Pool(5) as p:
        print(p.map(f, range(x_max)))

If a user (not me) writes their own script as:

# Script_of_a_good_user.py
from my_module import my_function

if __name__ == '__main__':
    my_function(10)

all is good and the output is printed as expected.

However, if a careless user writes their script as:

# Script_of_a_careless_user.py
from my_module import my_function

my_function(10)

then the process hangs: no output is produced, and no error message or warning is issued to the user.

Is there a way, inside my_function and BEFORE opening the Pool, to check whether the user put the call under the if __name__ == '__main__': condition in their script and, if not, raise an error telling them to do so?

NOTE: I think this behavior is only a problem on Windows machines where fork() is not available, as explained here.
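As a rough illustration of the note above, the following sketch asks multiprocessing which start method is in effect. The helper name start_method_reimports_main is made up for this example. It relies on the documented behaviour that the "Safe importing of main module" requirement applies to the spawn and forkserver start methods, not to fork.

```python
import multiprocessing


def start_method_reimports_main() -> bool:
    """Return True if workers are started by a method other than "fork".

    Hypothetical helper: with "spawn" or "forkserver", the main module must
    be safely importable by the worker processes; with "fork" it need not be.
    """
    # Note: get_start_method() fixes the start method to the platform
    # default if none has been chosen yet.
    return multiprocessing.get_start_method() != "fork"


if __name__ == "__main__":
    print(multiprocessing.get_start_method(), start_method_reimports_main())
```

A package could use such a check to decide whether a missing main guard needs to be reported at all on the current platform.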

Ted Klein Bergman
Luca
  • Just check or print the variable `__name__`. – yoonghm Dec 07 '21 at 12:43
  • I'm not sure I follow… your file should *define functions*, and any module importing that file can *call* those functions if and when needed, in `__main__` or otherwise. If somebody *directly* invokes your module, *you* can decide to call those functions in `__main__`, which only happens if and when somebody explicitly executed your file directly because they presumably wanted it to do something. – deceze Dec 07 '21 at 12:44
  • I hope the edit explains my question – Luca Dec 07 '21 at 13:08
  • Where's the recursion? If you import `my_module`, then `__name__ == "__main__"` will be *false* when executing the code in `my_module.py`. – chepner Dec 07 '21 at 13:11
  • @deceze Sometimes it is, unfortunately. – Ted Klein Bergman Dec 07 '21 at 13:14
  • Maybe I'm not using the correct wording. This is explained in https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming in the section "Safe importing of main module". The two examples ARE different if `multiprocessing` is used. Just test them and you see that the first one works, while the second doesn't – Luca Dec 07 '21 at 13:16
  • @deceze "It's not my job to protect against other developers writing crap code." That's true, but not using `if __name__ == "__main__":` is not by itself a huge error if you do that in the outermost script. However, it becomes mandatory if any inner function uses `multiprocessing`. Even if it's not my job to fix other people's code, it's usually good practice to raise a clear and explanatory exception if the user makes an error. Otherwise we would remove all the error messages and replace them with "You made an error, stupid! Find it out yourself" – Luca Dec 07 '21 at 13:19
  • Voting to reopen, I think this is a fairly clear question. Although you should probably edit some of the clarifications you gave us here in the comments, right into the question. I was not aware the multiprocessing module worked this way, but it makes sense, to have each process launch the script, but only the main process run the spawning code. – joanis Dec 07 '21 at 13:30
  • You're asking for a lot of code analysis, though, and that's probably going to be very challenging. Inside `my_function`, you could probably inspect the current stack, but the stack won't reflect conditional statements, only modules and functions, as far as I know. – joanis Dec 07 '21 at 13:34
  • @TedKleinBergman Are you running on Windows? Looking at this answer https://stackoverflow.com/a/20361032/4690023 I think the error only occurs on Windows – Luca Dec 07 '21 at 13:39
  • The traceback module could let you inspect the current stack. When you see stack trace dumps, you usually see line numbers and the line of code. I could imagine using that to inspect the code. It would probably fail when the source code was not available, e.g., when you're just running from the `.pyc` file, but this might work to catch most of the programmer errors you're trying to catch. I'm not sure it's worth the work, though. Document your module and provide stern warnings, that's what I would do. – joanis Dec 07 '21 at 13:43
  • @TedKleinBergman as specified in the `multiprocessing` documentation, using it is mandatory if any sub-function uses `multiprocessing`. The fact that it is usually not mandatory is why I think that, as the developer of the package, it should be on me to check whether it has been done and, if not, raise an explanatory error. Clearly, if this is not possible or very hard, I will fall back to well-written documentation (that 90% of the users won't read anyway) – Luca Dec 07 '21 at 13:53

2 Answers

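You can use the traceback module to inspect the stack and find the information you're looking for. Take the outermost frame (tb[0], normally the script that was launched), read its source file, and check whether the main shield appears at or before the line that made the call.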

I assume this will fail when you're working with a .pyc file and don't have access to the source code, but I assume developers will test their code in the regular fashion first before doing any kind of packaging, so I think it's safe to assume your error message will get printed when needed.

Version with verbose messages:

import traceback
import re

def called_from_main_shield():
    print("Calling introspect")
    tb = traceback.extract_stack()
    print(traceback.format_stack())
    print(f"line={tb[0].line} lineno={tb[0].lineno} file={tb[0].filename}")
    try:
        with open(tb[0].filename, mode="rt") as f:
            found_main_shield = False
            for i, line in enumerate(f, start=1):  # lineno is 1-based
                if re.search(r"__name__.*['\"]__main__['\"]", line):
                    found_main_shield = True
                if i == tb[0].lineno:
                    print(f"found_main_shield={found_main_shield}")
                    return found_main_shield
    except Exception:
        print("Couldn't inspect stack, let's pretend the code is OK...")
        return True

print(called_from_main_shield())

if __name__ == "__main__":
    print(called_from_main_shield())

In the output, we see that the first call to called_from_main_shield returns False, while the second returns True:

$ python3 introspect.py
Calling introspect
['  File "introspect.py", line 24, in <module>\n    print(called_from_main_shield())\n', '  File "introspect.py", line 7, in called_from_main_shield\n    print(traceback.format_stack())\n']
line=print(called_from_main_shield()) lineno=24 file=introspect.py
found_main_shield=False
False
Calling introspect
['  File "introspect.py", line 27, in <module>\n    print(called_from_main_shield())\n', '  File "introspect.py", line 7, in called_from_main_shield\n    print(traceback.format_stack())\n']
line=print(called_from_main_shield()) lineno=27 file=introspect.py
found_main_shield=True
True

More concise version:

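import re
import traceback

def called_from_main_shield():
    # tb[0] is the outermost frame, normally the script that was launched.
    tb = traceback.extract_stack()
    try:
        with open(tb[0].filename, mode="rt") as f:
            found_main_shield = False
            for i, line in enumerate(f, start=1):  # lineno is 1-based
                if re.search(r"__name__.*['\"]__main__['\"]", line):
                    found_main_shield = True
                if i == tb[0].lineno:
                    return found_main_shield
    except Exception:
        # Source not readable (e.g. only a .pyc is available): assume OK.
        return True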

Using re.search() like this is not very elegant, but it should be reliable enough. Warning: since I defined this function in my main script, I had to make sure that the regex line didn't match itself, which is why I used ['\"] to match the quotes instead of a simpler RE like __name__.*__main__. Whatever you choose, just make sure it's flexible enough to match all legal variants of that code, which is what I aimed for. Note also that the check only tells you whether a main shield appears somewhere at or before the calling line, not whether the call is actually inside it.
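If the regex feels too loose, one possible alternative is to parse the script with the ast module and test whether the calling line really sits inside a top-level main shield. This is only a sketch, not part of the tested code above: the helper names line_is_guarded and _is_main_guard are invented here, it assumes Python 3.8+ (for end_lineno and ast.Constant), and it only recognises the plain `__name__ == "__main__"` comparison at module level.

```python
import ast


def _is_main_guard(test: ast.expr) -> bool:
    """Return True for a test of the form `__name__ == "__main__"`."""
    if not (isinstance(test, ast.Compare)
            and len(test.ops) == 1
            and isinstance(test.ops[0], ast.Eq)):
        return False
    # Accept either operand order.
    operands = [test.left, test.comparators[0]]
    has_name = any(isinstance(o, ast.Name) and o.id == "__name__"
                   for o in operands)
    has_main = any(isinstance(o, ast.Constant) and o.value == "__main__"
                   for o in operands)
    return has_name and has_main


def line_is_guarded(source: str, lineno: int) -> bool:
    """Return True if 1-based `lineno` lies in the body of a top-level main shield."""
    for node in ast.parse(source).body:
        if isinstance(node, ast.If) and _is_main_guard(node.test):
            first = node.body[0].lineno
            last = node.body[-1].end_lineno
            if first <= lineno <= last:
                return True
    return False
```

Here source would be the text of the file named by tb[0].filename and lineno would be tb[0].lineno. Parsing does not import or run the script, so it is safe to call before opening the Pool.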

joanis
  • I accept this answer as I understand this is the closest thing we'll get to a solution. This is not completely failproof as, for example, when running inside an IDE (eg. Spyder) on windows, the first user module is not `tb[0]`, but could be down the line. I think this is specific to how Spyder has been written and I don't think this should be too relevant to this question. Hence, the answer is, in my opinion, to be accepted. – Luca Dec 07 '21 at 14:33
  • Interesting point about Spyder, I am not familiar with that IDE. Supporting that would require some more generalization of my code to ignore IDE frames on the stack, I agree. I just tested with `pdb` on the command line, and there too the main stack frame is not `tb[0]`. Room for improvement... – joanis Dec 07 '21 at 14:40
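One way to avoid relying on tb[0] might be to walk the stack outward from the caller and stop at the first module-level frame whose globals have __name__ set to "__main__". The sketch below is untested under Spyder or pdb, so treat it as a starting point only. The function name find_main_module_frame is hypothetical, and sys._getframe is a CPython implementation detail.

```python
import sys


def find_main_module_frame():
    """Return (filename, lineno) of the innermost module-level frame whose
    globals say __name__ == "__main__", or None if there is no such frame."""
    frame = sys._getframe(1)  # start at the caller of this function
    while frame is not None:
        if (frame.f_code.co_name == "<module>"
                and frame.f_globals.get("__name__") == "__main__"):
            return frame.f_code.co_filename, frame.f_lineno
        frame = frame.f_back  # move one frame outward
    return None
```

The returned filename and line number could then replace tb[0].filename and tb[0].lineno in called_from_main_shield. Returning None leaves it to the caller to decide what to do when no such frame exists.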

I think the best you can do is to try to execute the code and provide a hint if it fails. Something like this:

# my_module.py
import sys  # Use sys.stderr to print to the error stream.
from multiprocessing import Pool

def f(x):
    return x*x

def my_function(x_max):
    try:
        with Pool(5) as p:
            print(p.map(f, range(x_max)))
    except RuntimeError:
        print("Whoops! Did you perhaps forget to put the code in `if __name__ == '__main__'`?", file=sys.stderr)
        raise  # Re-raise the original exception with its traceback intact.

This is of course not a 100% solution, as there might be several other reasons the code throws a RuntimeError.


If it doesn't raise a RuntimeError, an ugly solution would be to explicitly force the user to pass in the name of the module.

# my_module.py
from multiprocessing import Pool

def f(x):
    return x*x

def my_function(x_max, module):
    """`module` must be set to `__name__`, for example `my_function(10, __name__)`"""
    if module == '__main__':
        with Pool(5) as p:
            print(p.map(f, range(x_max)))
    else:
        raise Exception("This can only be called from the main module.")

And call it as:

# Script_of_a_careless_user.py
from my_module import my_function
my_function(10, __name__)

This makes it very explicit to the user.

Ted Klein Bergman
  • I'm afraid this doesn't work. I just tested it on my Windows machine and `multiprocessing` hangs without raising any exception, so the `except` clause is never executed. Not even `CTRL+C` is able to stop the process, which just hangs – Luca Dec 07 '21 at 14:22
  • Weird, the [documentation](https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods) says it'll raise a `RuntimeError`. – Ted Klein Bergman Dec 07 '21 at 14:24
  • @Luca Then I think it's just best to be very explicit and go for simplicity. – Ted Klein Bergman Dec 07 '21 at 14:33