I would like to auto-differentiate across a rather complex function that I wish to parallelize.
I am using TensorFlow 2.x and using tf.GradientTape for differentiation.
I have made a toy example that illustrates the point. The auto-differentiation works just fine without threading but breaks when the exact same calculation is run in two separate threads.
import pdb
import tensorflow as tf
import threading
# This ThreadWithResult is from https://stackoverflow.com/a/65447493/1935801 and works fine on its own
class ThreadWithResult(threading.Thread):
    """threading.Thread subclass that captures the target's return value.

    Adapted from https://stackoverflow.com/a/65447493/1935801. After
    ``start()`` and ``join()``, the value returned by ``target`` is
    available as ``self.result``.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None):
        # Fix of the original's mutable default (kwargs={}): a shared dict
        # default would leak state between instances. None is the sentinel.
        if kwargs is None:
            kwargs = {}

        def function():
            # Runs on the worker thread; stash the return value on the
            # instance so the caller can read it after join().
            self.result = target(*args, **kwargs)

        super().__init__(group=group, target=function, name=name, daemon=daemon)
def my_function(x):
    """Return x**2 + x**3 — the single-threaded reference computation."""
    squared = x * x
    return squared + squared * x
def my_function_threaded(x):
    """Return x**2 + x**3, computing the two terms on separate threads.

    NOTE(review): per the quoted output below, gradients taken through
    this function come back as zero — the tape apparently does not see
    ops executed on the worker threads. Confirm against the tf.GradientTape
    docs before relying on this for differentiation.
    """
    def second_power(v):
        term = v * v
        return term

    def third_power(v):
        term = v * v * v
        return term

    workers = [
        ThreadWithResult(target=second_power, args=(x,)),
        ThreadWithResult(target=third_power, args=(x,)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    return workers[0].result + workers[1].result
# Demo: both variants compute the same forward value (36.0 for x = 3),
# but only the single-threaded one yields a nonzero gradient.
x = tf.constant(3.0)
print("my_function(x) =", my_function(x))
print("my_function_threaded(x) =", my_function_threaded(x))

# Single-threaded case: d/dx (x^2 + x^3) = 2x + 3x^2 = 33 at x = 3.
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, not a Variable, so it must be watched explicitly
    y = my_function(x)
dy_dx = tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print("Simple dy_dx", dy_dx)

# Threaded case: same computation, but the multiplications run on worker
# threads. The gradient prints as 0.0 here — presumably the tape does not
# record ops from other threads, and UnconnectedGradients.ZERO converts the
# resulting None into 0.0 (TODO: confirm against tf.GradientTape docs).
with tf.GradientTape() as tape:
    tape.watch(x)
    y = my_function_threaded(x)
dy_dx = tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print("Threaded dy_dx", dy_dx)
As one can see in the output shown below, gradients are broken when threading is used for the same simple calculation.
my_function(x) = tf.Tensor(36.0, shape=(), dtype=float32)
my_function_threaded(x) = tf.Tensor(36.0, shape=(), dtype=float32)
Simple dy_dx tf.Tensor(33.0, shape=(), dtype=float32)
Threaded dy_dx tf.Tensor(0.0, shape=(), dtype=float32)
Any suggestions or ideas on how I could parallelize my function within GradientTape would be much appreciated.