I am puzzled by the behavior I observe in the following example:
import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c
def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c
a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    b, c = f(a)
print('tf.function gradient: ', tape.gradient([b], [c]))
# outputs: tf.function gradient: [None]
with tf.GradientTape() as tape:
    b, c = fplain(a)
print('plain gradient: ', tape.gradient([b], [c]))
# outputs: plain gradient: [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]
The second (plain) behavior is what I would expect. How can I understand why the @tf.function version returns None?
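For what it is worth, differentiating with respect to the input a (rather than the intermediate c) seems to give the same result in both variants, so the issue looks specific to intermediate tensors returned from the traced function. A minimal check, reusing f, fplain, and a from above (the expected value 8a + 4 follows from b = sum((2a)**2 + 2*(2a))):

with tf.GradientTape() as tape:
    b, c = f(a)
# Differentiating w.r.t. the watched variable a works through @tf.function.
print('tf.function gradient w.r.t. a: ', tape.gradient([b], [a]))

with tf.GradientTape() as tape:
    b, c = fplain(a)
# The plain version gives the same result: db/da = 8a + 4.
print('plain gradient w.r.t. a: ', tape.gradient([b], [a]))
# both appear to output: [<tf.Tensor ... array([[ 4., 12.], [12.,  4.]] ...)>]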
Thank you very much in advance!
(Note that this problem is distinct from Missing gradient when using tf.function, since here all calculations are inside the function.)
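A workaround that appears to work (just a sketch, using a hypothetical helper f_with_grad) is to compute the gradient inside the traced function, where the intermediate c is still an ordinary tensor on the tape; I would still like to understand why the original version returns None:

@tf.function
def f_with_grad(a):
    with tf.GradientTape() as tape:
        # a is a tf.Variable, so the tape watches it automatically
        c = a * 2
        b = tf.reduce_sum(c ** 2 + 2 * c)
    # Inside the function the tape has recorded the ops leading from c to b,
    # so this gradient is defined.
    return tape.gradient([b], [c])

print('gradient inside tf.function: ', f_with_grad(a))
# appears to output the same [[2., 6.], [6., 2.]] as the plain case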