import tensorflow as tf
import numpy as np

# 20 samples, each a sequence of 10 timesteps with 64 features
x = np.random.rand(20, 10, 64)
# one integer class label (0-9) per sample
y = np.random.randint(10, size=(20, 1))

class mymodel(tf.keras.Model):
    def __init__(self):
        super(mymodel, self).__init__()
        # return_state=True makes the LSTM return [output, h_state, c_state]
        self.l1 = tf.keras.layers.LSTM(10, return_state=True)
        self.l2 = tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)

    def call(self, inputs):
        print('hi')  # plain Python print
        x = self.l1(inputs)
        # tf.print(x[0], x[1], x[2])
        x = self.l2(x[0])  # classify from the LSTM output only
        return x

model = mymodel()
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())
model.fit(x, y)

When I run the code above, "hi" is printed twice. In this example there is only 1 epoch, and within that epoch only 1 batch (the batch size being 20), so why is the call method getting called twice?

Abhishek Kishore

1 Answer


Related: tf.print() vs Python print vs tensor.eval()

The Python print inside the call method is executed only while the graph is being built underneath, i.e. the first time. The second call is triggered by TF eager execution, again only the first time (if you run tf.compat.v1.disable_eager_execution() before model.fit, you will see "hi" printed only once).
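The underlying behaviour (a Python side effect runs only when `tf.function` traces a function, while `tf.print` becomes an op in the graph and runs on every call) can be seen without Keras. This is a minimal sketch assuming TF 2.x with eager execution enabled (the default); `f` and `traces` are hypothetical names used only for illustration.

```python
import tensorflow as tf

traces = []  # hypothetical log: one entry each time the Python body runs

@tf.function
def f(x):
    traces.append("traced")  # Python side effect: runs only while tracing
    tf.print("executing")    # graph op: runs on every call
    return x * 2

f(tf.constant(1.0))          # first call: traces, then executes
f(tf.constant(2.0))          # same dtype/shape: reuses the graph, no trace
f(tf.constant([1.0, 2.0]))   # new shape: traces a second graph
print(len(traces))           # 2, while "executing" is printed 3 times
```

The same rule applies to model.fit, which runs the training step (and therefore call) inside a tf.function: the Python print shows up only when that function is traced.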

However, if you run model.fit a second time (or a third, ...), you will notice that nothing gets printed. This is because, once the graph has been built, the forward pass no longer executes the Python body of the call method. If you want to print something on every execution of the forward pass, use tf.print("hi") instead. This way, exactly one print happens per forward pass, which in this example means once per epoch of model.fit, since each epoch is a single batch.
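A minimal sketch of the tf.print suggestion applied to the model from the question, assuming TF 2.x with eager execution enabled. `python_calls` is a hypothetical counter added only to make the effect visible. The exact number of Python-level calls during the first fit may differ between TF/Keras versions, so the sketch only compares the count before and after a second fit.

```python
import numpy as np
import tensorflow as tf

# Hypothetical counter: grows each time the Python body of call() runs
python_calls = []

class mymodel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.LSTM(10, return_state=True)
        self.l2 = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        python_calls.append(1)  # Python side effect: only when call() is traced
        tf.print("hi")          # graph op: printed on every forward pass
        out = self.l1(inputs)
        return self.l2(out[0])

x = np.random.rand(20, 10, 64)
y = np.random.randint(10, size=(20, 1))

model = mymodel()
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())

model.fit(x, y, verbose=0)    # traces the training step; "hi" printed once
after_first_fit = len(python_calls)
model.fit(x, y, verbose=0)    # reuses the cached graph; "hi" printed again
after_second_fit = len(python_calls)
print(after_first_fit, after_second_fit)
```

With this change, "hi" is printed on each fit, while the Python-level counter stops growing after the first one.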

ibarrond