0

For some reason, I get different tensor shapes when using `tf.gather` in TF 2:

  1. The first dimension becomes None when I use a tensor as the index vector.
  2. The first dimension becomes len(indices) (as it should) when `indices` is a regular Python list.

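This happens only when the code runs inside a custom loss function during `model.fit`. Keras traces such a loss into a graph rather than running it eagerly, unless `run_eagerly=True` is set.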

(The same happens with `tf.boolean_mask`.)

EDIT: The following code reproduces the problem with TF 2.7.0 and Python 3.8.10:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Reshape
from tensorflow.keras.datasets import mnist

def cutsom_gan_loss_env(model):
    """Build a Keras loss closure that averages the model's input gradient.

    The returned ``custom_loss`` also prints the static shape of two
    ``tf.gather`` results: one gathered with a Python list of indices and one
    gathered with a tensor of indices.

    Args:
        model: Keras model; it is called on ``y_true`` inside the loss.

    Returns:
        A ``custom_loss(y_true, y_pred)`` function suitable for
        ``model.compile(loss=...)``. ``y_pred`` is not used by the loss.
    """
    def custom_loss(y_true, y_pred):
        # Indices [0, 1] produced as a tensor, so their count is data-dependent.
        ff = tf.where([True, True, False, False])[:, 0]

        # `.shape` is the static shape. With a Python list the first dimension
        # is len(indices); with the tensor `ff` it can show up as None while
        # the loss is traced. tf.shape(...) would give the run-time shape.
        tf.print(tf.gather(y_true, [0, 1], axis=0).shape)
        tf.print(tf.gather(y_true, ff, axis=0).shape)

        # Only one gradient is taken, so a non-persistent tape is enough.
        with tf.GradientTape() as tape:
            tape.watch(y_true)
            yy = model(y_true)
        # Take the gradient after leaving the tape context so the gradient
        # computation itself is not recorded on this tape.
        d_yy = tape.gradient(yy, y_true)
        return tf.reduce_mean(d_yy)

    return custom_loss


def main_():
    """Fit a small dense network on MNIST images using the custom loss.

    The training images are used as both inputs and targets, and training
    runs for a single epoch with batches of 1000.
    """
    n_hidden_units = 5
    num_lay = 3
    kernel_init = keras.initializers.RandomUniform(-0.1, 0.1)
    (x_train, _), _ = mnist.load_data()
    # Convert to float and divide by 255 (presumably mapping 0-255 pixel
    # values into [0, 1]).
    x_train = tf.cast(x_train, tf.float32) / 255.
    inputs = Input(x_train.shape[1:])
    x = Dense(n_hidden_units, kernel_initializer=kernel_init,
              activation='sigmoid')(inputs)
    for _ in range(num_lay):
        x = Dense(n_hidden_units, kernel_initializer=kernel_init,
                  activation='sigmoid')(x)

    # Reshape the output to the per-sample input shape so it matches the
    # targets passed to model.fit (which are the inputs themselves).
    outputs = Reshape(x_train.shape[1:])(
        Dense(x_train.shape[1], kernel_initializer=kernel_init,
              activation='softmax')(x))
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()
    # epsilon=None and decay=0.0 only restated the defaults in TF 2.7; the
    # newer Keras optimizers (TF >= 2.11) reject them, so they are omitted.
    optimizer1 = keras.optimizers.Adam(beta_1=0.9, beta_2=0.999, amsgrad=True)
    model.compile(loss=cutsom_gan_loss_env(model), optimizer=optimizer1,
                  metrics=None)
    model.fit(x_train, x_train, batch_size=1000, epochs=1, shuffle=False)


if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main_()
Benny K
  • 1,957
  • 18
  • 33

1 Answer

1

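This is not an error but the difference between `tensor.shape` and `tf.shape(tensor)`. `tensor.shape` is the static shape, fixed when Keras traces the loss function into a graph. When the number of gathered rows depends on a tensor (such as `ff`), that dimension is unknown at trace time and shows up as `None`. `tf.shape` instead returns the dynamic shape, computed when an operation like `tf.gather` actually runs.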

Change:

# `.shape` is the static shape: the first dim is len([0, 1]) for the list,
# but may be None for the tensor index `ff`.
tf.print(tf.gather(y_true, [0, 1], axis=0).shape)
tf.print(tf.gather(y_true, ff, axis=0).shape)

To:

# tf.shape returns the dynamic shape, computed when the ops actually run.
tf.print(tf.shape(tf.gather(y_true, [0, 1], axis=0)))
tf.print(tf.shape(tf.gather(y_true, ff, axis=0)))

With `tf.shape`, the shapes are computed when the ops run during `model.fit`, so both calls print the actual first dimension. Also read this post for a better understanding.

AloneTogether
  • 25,814
  • 5
  • 20
  • 39