I was reading about race conditions in Python multithreading and came across some code snippets illustrating them.
When I tried executing them, the behavior (i.e. the output value) changed as a function of time.
I ran the code in Jupyter Notebook. This was from a tutorial in PyCon 2020.
Code Snippet (I am putting them in separate blocks as an equivalent to me running them in different cells of the notebook):
import threading, time, random

COUNTER = 0

def increment(n):
    global COUNTER  # Use the global variable declared above
    for _ in range(n):
        COUNTER += 1
        time.sleep(0.001)
    print(f"{threading.current_thread().name} finished counting")

ITERATIONS = 10000  # set number of iterations to increase counter
# create 10 threads, each set to run the increment() function
threads = [threading.Thread(target=increment, args=(ITERATIONS,)) for _ in range(10)]
# start every thread; start() returns without waiting for the thread to finish
for t in threads:
    t.start()

assert COUNTER == (len(threads) * ITERATIONS), f"Invalid value for counter: {COUNTER}, expected value: {len(threads) * ITERATIONS}"
COUNTER's value for a 5 sec vs 10 sec gap between running the last 2 cells seems to be 100,000 vs 200,000. In the last cell, even the small time gap between running the print function and the assert statement results in a change in value.
My guess is that the threads executing increment() take longer to run than the thread executing the notebook cells. But I'd love to hear the community's explanation and perspective (and also any advice for me).
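A minimal sketch of that guess, assuming a smaller workload than the cells above (4 threads and 300 iterations are illustrative values, and snapshot and still_running are names introduced only for this example). It reads COUNTER immediately after start(), the way a later notebook cell would, and checks whether any worker is still alive at that moment:

```python
import threading
import time

COUNTER = 0

def increment(n):
    global COUNTER
    for _ in range(n):
        COUNTER += 1
        time.sleep(0.001)

ITERATIONS = 300  # illustrative value, smaller than in the notebook
threads = [threading.Thread(target=increment, args=(ITERATIONS,)) for _ in range(4)]
for t in threads:
    t.start()  # returns immediately; the worker keeps running in the background

expected = len(threads) * ITERATIONS

# Read the counter right away, without waiting for the workers.
snapshot = COUNTER
still_running = any(t.is_alive() for t in threads)
print(f"snapshot={snapshot}, expected={expected}, still running: {still_running}")
```

If the guess is right, the snapshot should be well below the expected total and at least one thread should still be reported as running, which would match the value drifting upward between cells.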
As a follow-up question, if I were to execute similar code inside a script, how do I ensure that the threads fully finish executing before moving on to further process the COUNTER variable? I am aware of threading.Thread.join(), but when I invoked it right after starting each thread, all of them executed sequentially (like the code snippet below). What I was wondering is whether there is a way to let all the threads run concurrently, wait until they have all finished, and then process COUNTER?
# Using join() this way results in all threads getting sequentially executed
for t in threads:
    t.start()
    t.join()
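For comparison, a minimal sketch of the commonly suggested alternative: start every thread in one loop, then call join() in a second loop, so the threads overlap while the main thread waits. The smaller workload, the LOCK around the increment, and the elapsed timing are assumptions added here so that the final count is deterministic and the run is short; they are not part of the tutorial code.

```python
import threading
import time

COUNTER = 0
LOCK = threading.Lock()  # assumption: added so no increment is lost to a race

def increment(n):
    global COUNTER
    for _ in range(n):
        with LOCK:
            COUNTER += 1
        time.sleep(0.001)

ITERATIONS = 200  # illustrative value, smaller than in the notebook
threads = [threading.Thread(target=increment, args=(ITERATIONS,)) for _ in range(5)]

started = time.perf_counter()
for t in threads:
    t.start()  # first loop: launch every thread without waiting
for t in threads:
    t.join()   # second loop: block until each thread has finished
elapsed = time.perf_counter() - started

expected = len(threads) * ITERATIONS
assert COUNTER == expected, f"Invalid value for counter: {COUNTER}, expected value: {expected}"
print(f"COUNTER={COUNTER}, elapsed={elapsed:.2f}s")
```

Because all threads are already running before the first join() call, the elapsed time should be close to the duration of a single thread rather than the sum of all of them, and COUNTER is only processed once every thread has finished.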
Thanks in advance!